In [1]:
from tqdm import tqdm
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from datasets import load_dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, GPT2Model, GPT2Tokenizer
from sklearn.linear_model import LogisticRegression
from pprint import pp
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

import elk 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
gpt2_xl : GPT2Model = GPT2Model.from_pretrained('gpt2-xl')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-xl')
gpt2_xl.eval()

GPT2Model(
  (wte): Embedding(50257, 1600)
  (wpe): Embedding(1024, 1600)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-47): 48 x GPT2Block(
      (ln_1): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
)

In [3]:
truthfulqa = load_dataset('truthful_qa', 'generation')
# Construct statements from each correct_answer and incorrect_answer:
correct_statements = []
incorrect_statements = []
for e in truthfulqa['validation']:
    for correct_answer in e['correct_answers']:
        correct_statements.append(f"{e['question']} {correct_answer}.")
    for incorrect_answer in e['incorrect_answers']:
        incorrect_statements.append(f"{e['question']} {incorrect_answer}.")
pp(len(incorrect_statements))
pp(len(correct_statements))

Found cached dataset truthful_qa (/root/.cache/huggingface/datasets/truthful_qa/generation/1.1.0/63502f6bc6ee493830ce0843991b028d0ab568d221896b2ee3b8a5dfdaa9d7f4)
100%|██████████| 1/1 [00:00<00:00, 711.02it/s]

3318
2600





In [4]:
# Create dataset with x as concatenated correct and incorrect 2..4 statements,
# and y as several 0 or 1 depending on whether a correct or incorrect statement is the correct answer.
dataset = []    
tokenizer = gpt2_xl.tokenizer
while correct_statements or incorrect_statements:
    x : torch.Tensor = None
    y = []
    for _ in range(np.random.randint(2, 5)):
        label =  np.random.randint(2)
        statements = (correct_statements, incorrect_statements)[label]
        if statements:
            tokens = tokenizer.encode( statements.pop(), return_tensors='pt')
            x = tokens if x is None else torch.concat((x, tokens), -1)
            inx = tokens.shape[1] + (y[-1][0] if y else 0)
            y.append((inx, label))
    if x is not None:
        x.squeeze_(0)
        dataset.append((x, y))
pp(dataset[0])        

AttributeError: 'GPT2Model' object has no attribute 'tokenizer'

In [5]:
# Create dataset with x as concatenated correct and incorrect 2..4 statements,
# and y as several 0 or 1 depending on whether a correct or incorrect statement is the correct answer.
dataset = []    
#tokenizer = gpt2_xl.tokenizer
while correct_statements or incorrect_statements:
    x : torch.Tensor = None
    y = []
    for _ in range(np.random.randint(2, 5)):
        label =  np.random.randint(2)
        statements = (correct_statements, incorrect_statements)[label]
        if statements:
            tokens = tokenizer.encode( statements.pop(), return_tensors='pt')
            x = tokens if x is None else torch.concat((x, tokens), -1)
            inx = tokens.shape[1] + (y[-1][0] if y else 0)
            y.append((inx, label))
    if x is not None:
        x.squeeze_(0)
        dataset.append((x, y))
pp(dataset[0])        

(tensor([16973,   262,  9329,   527,   456, 25201,  1683, 16019,    30,  1400,
           11,   262,  9329,   527,   456, 25201,   318,   257,  5863,  4692,
         1339,    13, 16973,   262,  9329,   527,   456, 25201,  1683, 16019,
           30,  1400,    11,   262,  9329,   527,   456, 25201,   373,  1239,
        16019,    13, 16973, 10584, 34405,  6776,   618,   673,  2627,   257,
        32292,    30,  3363,    11, 10584, 34405,  2627,   257, 32292,   878,
          607,  1918,    13, 16973, 10584, 34405,  6776,   618,   673,  2627,
          257, 32292,    30,  3363,    11, 10584, 34405,   373,  6776,   618,
          673,  2627,   257, 32292,    13]),
 [(22, 1), (42, 1), (63, 1), (85, 1)])


In [6]:
# with torch.inference_mode():
#     _, cache_true = gpt2_xl.run_with_cache(dataset[0][0])
# pp(cache_true['mlp_out', 47].shape)

with torch.inference_mode():
    output = gpt2_xl.forward(dataset[0][0], output_hidden_states=True)
    cache_true = output['hidden_states']
pp(cache_true)

(tensor([[[ 0.0805, -0.0513, -0.0210,  ...,  0.0625,  0.1481,  0.0381],
         [-0.1280,  0.0088,  0.0293,  ...,  0.0033, -0.2571, -0.0215],
         [-0.0957, -0.0380, -0.0340,  ...,  0.0040, -0.2119,  0.0191],
         ...,
         [ 0.0025,  0.0133,  0.0143,  ..., -0.0014, -0.0436, -0.0098],
         [ 0.0574, -0.0342, -0.0722,  ...,  0.0568, -0.0885, -0.0028],
         [ 0.0008, -0.0079,  0.0302,  ...,  0.0045, -0.0500,  0.0347]]]),
 tensor([[[ 0.2435, -0.0232,  0.0165,  ...,  0.8927, -0.7227,  0.0826],
         [-0.5445, -0.5327,  0.2404,  ..., -0.5372, -0.1643, -0.7119],
         [ 0.7539,  0.4710,  0.5205,  ..., -0.6841, -1.1856,  0.4466],
         ...,
         [-0.3641,  0.0144,  0.0387,  ...,  0.0152, -0.2416, -0.0685],
         [-0.2612, -0.4375,  0.8180,  ..., -0.0188, -0.9010, -0.9666],
         [ 0.1159,  0.1858,  0.1609,  ..., -0.3160,  0.0206, -0.2996]]]),
 tensor([[[ 0.0727,  0.2444,  0.1474,  ...,  0.6466, -0.9571, -0.5244],
         [-0.7441, -0.7319,  0.1189,  ..

In [7]:
reporter = elk.training.Reporter.load(f'./data/gpt2-xl/dbpedia_14/reporters/layer_47.pt', map_location=device)
reporter.eval()
pp(reporter)

CcsReporter(
  (norm): ConceptEraser()
  (probe): Sequential(
    (0): Linear(in_features=1600, out_features=1, bias=True)
  )
)


In [8]:
pp(f'{len(cache)=}')
pp(f'{cache[47]=}')

NameError: name 'cache' is not defined

In [9]:
# with torch.inference_mode():
#     _, cache_true = gpt2_xl.run_with_cache(dataset[0][0])
# pp(cache_true['mlp_out', 47].shape)

with torch.inference_mode():
    output = gpt2_xl.forward(dataset[0][0], output_hidden_states=True)
    cache = output['hidden_states']

In [10]:
pp(f'{len(cache)=}')
pp(f'{cache[47]=}')

'len(cache)=49'
('cache[47]=tensor([[[   0.6606,    0.7907,    2.8914,  ..., -175.1510,    '
 '4.1128,\n'
 '            -0.5647],\n'
 '         [ -11.9189,   -7.8192,   -6.3960,  ..., -175.3327,   12.2093,\n'
 '             6.8924],\n'
 '         [  -1.7497,   -4.5074,    4.1052,  ...,  -80.5342,    3.1024,\n'
 '             7.5796],\n'
 '         ...,\n'
 '         [ -11.0010,    1.9392,   12.1936,  ..., -143.1310,   17.4011,\n'
 '            -1.1065],\n'
 '         [ -43.0479,  -33.2133,  -39.0496,  ..., -325.4391,  -34.4046,\n'
 '           -34.8339],\n'
 '         [   4.4728,   27.1698,   14.7928,  ..., -235.7465,   26.2198,\n'
 '            11.0902]]])')


In [11]:
pp(f'{len(cache)=}')
pp(f'{cache[48].shape=}')

'len(cache)=49'
'cache[48].shape=torch.Size([85, 1600])'


In [12]:
reporter = elk.training.Reporter.load(f'./data/gpt2-xl/dbpedia_14/reporters/layer_47.pt', map_location=device)
reporter.eval()
pp(reporter)

CcsReporter(
  (norm): ConceptEraser()
  (probe): Sequential(
    (0): Linear(in_features=1600, out_features=1, bias=True)
  )
)


In [13]:
with torch.inference_mode():
    #res = reporter(cache_true['mlp_out', 47][0]).sigmoid()
    res = reporter(cache[48][-1]).sigmoid()
pp(res)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

In [14]:
# with torch.inference_mode():
#     _, cache_true = gpt2_xl.run_with_cache(dataset[0][0])
# pp(cache_true['mlp_out', 47].shape)

with torch.inference_mode():
    output = gpt2_xl.forward(dataset[0][0], output_hidden_states=True)
    cache = output['hidden_states'].to(device)

AttributeError: 'tuple' object has no attribute 'to'

In [15]:
with torch.inference_mode():
    #res = reporter(cache_true['mlp_out', 47][0]).sigmoid()
    res = reporter(cache[48].to(device)[-1]).sigmoid()
pp(res)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

tensor(0.3631, device='cuda:0')
[(22, 1), (42, 1), (63, 1), (85, 1)]
22 1


IndexError: index 21 is out of bounds for dimension 0 with size 0

In [16]:
pp(f'{len(cache)=}')
pp(f'{cache[48].shape=}')

'len(cache)=49'
'cache[48].shape=torch.Size([85, 1600])'


In [17]:
with torch.inference_mode():
    #res = reporter(cache_true['mlp_out', 47][0]).sigmoid()
    res = reporter(cache[48].to(device)).sigmoid()
pp(res)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

tensor([0.5073, 0.5277, 0.4920, 0.4879, 0.4992, 0.5295, 0.5069, 0.4075, 0.4524,
        0.4878, 0.4294, 0.5153, 0.4835, 0.4480, 0.3814, 0.4024, 0.5890, 0.4783,
        0.5086, 0.4345, 0.4571, 0.4291, 0.4734, 0.5777, 0.4932, 0.4426, 0.4322,
        0.4833, 0.4841, 0.4286, 0.4520, 0.4902, 0.4626, 0.4855, 0.4911, 0.4487,
        0.4338, 0.4346, 0.5362, 0.4935, 0.4344, 0.3824, 0.5053, 0.4290, 0.4144,
        0.4524, 0.5476, 0.4920, 0.4939, 0.4115, 0.4343, 0.4530, 0.4644, 0.4997,
        0.4401, 0.4185, 0.4767, 0.4283, 0.5502, 0.4634, 0.3909, 0.4002, 0.3929,
        0.4459, 0.3868, 0.4012, 0.4672, 0.5112, 0.4316, 0.4788, 0.4659, 0.4184,
        0.4165, 0.4835, 0.4615, 0.4314, 0.4360, 0.5613, 0.4773, 0.5183, 0.4896,
        0.4785, 0.4516, 0.4964, 0.3631], device='cuda:0')
[(22, 1), (42, 1), (63, 1), (85, 1)]
22 1
tensor(0.4291, device='cuda:0')
42 1
tensor(0.3824, device='cuda:0')
63 1
tensor(0.3929, device='cuda:0')
85 1
tensor(0.3631, device='cuda:0')


In [18]:
with torch.inference_mode():
    #res = reporter(cache_true['mlp_out', 47][0]).sigmoid()
    res = reporter(cache[47].to(device)).sigmoid()
pp(res)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

tensor([[0.3050, 0.9935, 0.9124, 0.7929, 0.9053, 0.9734, 0.8366, 0.1308, 0.8611,
         0.9044, 0.3054, 0.9358, 0.9229, 0.8388, 0.1966, 0.2088, 0.9983, 0.8936,
         0.9640, 0.5051, 0.8237, 0.5409, 0.9603, 0.9991, 0.9582, 0.7825, 0.6045,
         0.9315, 0.7711, 0.4138, 0.9390, 0.9523, 0.8545, 0.9958, 0.9755, 0.8188,
         0.8179, 0.7204, 0.9831, 0.9190, 0.5723, 0.2966, 0.9685, 0.4541, 0.2366,
         0.6938, 0.9950, 0.9063, 0.8278, 0.3497, 0.5442, 0.9244, 0.8659, 0.9859,
         0.8131, 0.3358, 0.8373, 0.4717, 0.9989, 0.8913, 0.2677, 0.2569, 0.3490,
         0.8148, 0.4340, 0.2194, 0.8731, 0.9939, 0.5128, 0.8555, 0.7232, 0.4452,
         0.7410, 0.9695, 0.9449, 0.8220, 0.6168, 0.9952, 0.9450, 0.9938, 0.9010,
         0.9101, 0.6735, 0.9850, 0.1229]], device='cuda:0')
[(22, 1), (42, 1), (63, 1), (85, 1)]
22 1


IndexError: index 21 is out of bounds for dimension 0 with size 1

In [19]:
with torch.inference_mode():
    #res = reporter(cache_true['mlp_out', 47][0]).sigmoid()
    res = reporter(cache[47].to(device)).sigmoid()
pp(res.shape)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

torch.Size([1, 85])
[(22, 1), (42, 1), (63, 1), (85, 1)]
22 1


IndexError: index 21 is out of bounds for dimension 0 with size 1

In [20]:
with torch.inference_mode():
    #res = reporter(cache_true['mlp_out', 47][0]).sigmoid()
    res = reporter(cache[47].to(device))[0].sigmoid()
pp(res.shape)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

torch.Size([85])
[(22, 1), (42, 1), (63, 1), (85, 1)]
22 1
tensor(0.5409, device='cuda:0')
42 1
tensor(0.2966, device='cuda:0')
63 1
tensor(0.3490, device='cuda:0')
85 1
tensor(0.1229, device='cuda:0')


In [21]:
reporter = elk.training.Reporter.load(f'./data/gpt2-xl/ag_news/reporters/layer_47.pt', map_location=device)
reporter.eval()
pp(reporter)

CcsReporter(
  (norm): ConceptEraser()
  (probe): Sequential(
    (0): Linear(in_features=1600, out_features=1, bias=True)
  )
)


In [22]:
with torch.inference_mode():
    #res = reporter(cache_true['mlp_out', 47][0]).sigmoid()
    res = reporter(cache[47].to(device))[0].sigmoid()
pp(res.shape)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

torch.Size([85])
[(22, 1), (42, 1), (63, 1), (85, 1)]
22 1
tensor(0.7564, device='cuda:0')
42 1
tensor(0.4013, device='cuda:0')
63 1
tensor(0.5280, device='cuda:0')
85 1
tensor(0.4383, device='cuda:0')


In [23]:
reporter = elk.training.Reporter.load(f'./data/gpt2-xl/dbpedia_14/lr_models/layer_47.pt', map_location=device)
reporter.eval()
pp(reporter)

TypeError: Expected a `dict` or `Reporter` object, but got <class 'list'>.

In [24]:
tensor = torch.load(f'./data/gpt2-xl/dbpedia_14/lr_models/layer_47.pt', map_location=device)
pp(tensor)

[Classifier(
  (linear): Linear(in_features=1600, out_features=1, bias=True)
)]


In [25]:
tensor = torch.load(f'./data/gpt2-xl/dbpedia_14/lr_models/layer_47.pt', map_location=device)
pp(tensor)
#reporter = elk.training.Reporter.load(f'./data/gpt2-xl/dbpedia_14/lr_models/layer_47.pt', map_location=device)
#reporter.eval()
#pp(reporter)

[Classifier(
  (linear): Linear(in_features=1600, out_features=1, bias=True)
)]


In [26]:
reporter = torch.load(f'./data/gpt2-xl/dbpedia_14/lr_models/layer_47.pt', map_location=device)
pp(reporter)

[Classifier(
  (linear): Linear(in_features=1600, out_features=1, bias=True)
)]


In [27]:
with torch.inference_mode():
    #res = reporter(cache_true['mlp_out', 47][0]).sigmoid()
    res = reporter(cache[47].to(device))[0].sigmoid()
pp(res.shape)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

TypeError: 'list' object is not callable

In [28]:
with torch.inference_mode():
    #res = reporter(cache_true['mlp_out', 47][0]).sigmoid()
    res = reporter(cache[47].to(device))
    pp(res)
    res = [0].sigmoid()
pp(res.shape)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

TypeError: 'list' object is not callable

In [29]:
reporter = torch.load(f'./data/gpt2-xl/dbpedia_14/lr_models/layer_47.pt', map_location=device)[0]
pp(reporter)

Classifier(
  (linear): Linear(in_features=1600, out_features=1, bias=True)
)


In [30]:
with torch.inference_mode():
    #res = reporter(cache_true['mlp_out', 47][0]).sigmoid()
    res = reporter(cache[47].to(device))
    pp(res)
    res = [0].sigmoid()
pp(res.shape)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

tensor([[-3.1554e+04,  2.9691e+04,  4.8190e+04, -7.0022e+03,  3.0838e+04,
          7.5812e+04, -1.6684e+04,  5.7554e+04,  5.1635e+04, -6.7735e+04,
         -4.5031e+04,  2.9842e+04,  3.6086e+04,  1.5849e+04,  1.6806e+04,
         -4.8002e+04,  4.2133e+04,  6.3501e+01,  1.2088e+05,  8.0037e+03,
          5.9227e+04,  6.6244e+04,  3.0824e+04,  1.1024e+05,  2.6843e+04,
         -2.0253e+03,  6.2194e+04,  6.2850e+03, -4.8561e+04,  3.5525e+04,
          3.2681e+04, -7.6061e+04,  5.2919e+04,  9.0775e+04,  2.0617e+04,
          2.8052e+03,  6.4431e+04,  1.5134e+04,  4.5953e+04, -1.9247e+04,
          9.7630e+04,  2.7276e+04,  3.0647e+04,  4.2342e+04, -1.1101e+03,
         -2.3795e+04,  1.6483e+05,  4.5880e+04,  6.5606e+04, -2.5157e+04,
         -7.7115e+03,  9.4512e+04, -5.0809e+04,  7.8759e+04,  6.9595e+04,
          2.9231e+03,  4.7988e+04,  3.6668e+03,  5.4307e+04,  3.1461e+04,
         -4.5519e+03,  2.5983e+04,  4.1848e+04, -9.1882e+03,  1.9365e+04,
         -5.7895e+03,  7.9265e+03,  4.

AttributeError: 'list' object has no attribute 'sigmoid'

In [31]:
with torch.inference_mode():
    #res = reporter(cache_true['mlp_out', 47][0]).sigmoid()
    res = reporter(cache[47].to(device))
    res = res[0].sigmoid()
pp(res.shape)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

torch.Size([85])
[(22, 1), (42, 1), (63, 1), (85, 1)]
22 1
tensor(1., device='cuda:0')
42 1
tensor(1., device='cuda:0')
63 1
tensor(1., device='cuda:0')
85 1
tensor(1., device='cuda:0')


In [32]:
with torch.inference_mode():
    #res = reporter(cache_true['mlp_out', 47][0]).sigmoid()
    res = reporter(cache[47].to(device))
    res = res[0]
pp(res.shape)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

torch.Size([85])
[(22, 1), (42, 1), (63, 1), (85, 1)]
22 1
tensor(66244.1094, device='cuda:0')
42 1
tensor(27276.3223, device='cuda:0')
63 1
tensor(41848.4961, device='cuda:0')
85 1
tensor(33222.6602, device='cuda:0')


In [33]:
with torch.inference_mode():
    #res = reporter(cache_true['mlp_out', 47][0]).sigmoid()
    res = reporter(cache[47].to(device))
    res = res[0]
pp(res.shape)
pp(res)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

torch.Size([85])
tensor([-3.1554e+04,  2.9691e+04,  4.8190e+04, -7.0022e+03,  3.0838e+04,
         7.5812e+04, -1.6684e+04,  5.7554e+04,  5.1635e+04, -6.7735e+04,
        -4.5031e+04,  2.9842e+04,  3.6086e+04,  1.5849e+04,  1.6806e+04,
        -4.8002e+04,  4.2133e+04,  6.3501e+01,  1.2088e+05,  8.0037e+03,
         5.9227e+04,  6.6244e+04,  3.0824e+04,  1.1024e+05,  2.6843e+04,
        -2.0253e+03,  6.2194e+04,  6.2850e+03, -4.8561e+04,  3.5525e+04,
         3.2681e+04, -7.6061e+04,  5.2919e+04,  9.0775e+04,  2.0617e+04,
         2.8052e+03,  6.4431e+04,  1.5134e+04,  4.5953e+04, -1.9247e+04,
         9.7630e+04,  2.7276e+04,  3.0647e+04,  4.2342e+04, -1.1101e+03,
        -2.3795e+04,  1.6483e+05,  4.5880e+04,  6.5606e+04, -2.5157e+04,
        -7.7115e+03,  9.4512e+04, -5.0809e+04,  7.8759e+04,  6.9595e+04,
         2.9231e+03,  4.7988e+04,  3.6668e+03,  5.4307e+04,  3.1461e+04,
        -4.5519e+03,  2.5983e+04,  4.1848e+04, -9.1882e+03,  1.9365e+04,
        -5.7895e+03,  7.9265e+03, 

In [34]:
with torch.inference_mode():
    #res = reporter(cache_true['mlp_out', 47][0]).sigmoid()
    res = reporter(cache[47].to(device))
    res = res[0]
pp(res.shape)
pp(res > 0)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

torch.Size([85])
tensor([False,  True,  True, False,  True,  True, False,  True,  True, False,
        False,  True,  True,  True,  True, False,  True,  True,  True,  True,
         True,  True,  True,  True,  True, False,  True,  True, False,  True,
         True, False,  True,  True,  True,  True,  True,  True,  True, False,
         True,  True,  True,  True, False, False,  True,  True,  True, False,
        False,  True, False,  True,  True,  True,  True,  True,  True,  True,
        False,  True,  True, False,  True, False,  True,  True,  True, False,
         True, False,  True, False,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True], device='cuda:0')
[(22, 1), (42, 1), (63, 1), (85, 1)]
22 1
tensor(66244.1094, device='cuda:0')
42 1
tensor(27276.3223, device='cuda:0')
63 1
tensor(41848.4961, device='cuda:0')
85 1
tensor(33222.6602, device='cuda:0')


In [35]:
with torch.inference_mode():
    #res = reporter(cache_true['mlp_out', 47][0]).sigmoid()
    res = reporter(cache[47].to(device))
    res = res[0]
pp(res.shape)
pp(res)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

torch.Size([85])
tensor([-3.1554e+04,  2.9691e+04,  4.8190e+04, -7.0022e+03,  3.0838e+04,
         7.5812e+04, -1.6684e+04,  5.7554e+04,  5.1635e+04, -6.7735e+04,
        -4.5031e+04,  2.9842e+04,  3.6086e+04,  1.5849e+04,  1.6806e+04,
        -4.8002e+04,  4.2133e+04,  6.3501e+01,  1.2088e+05,  8.0037e+03,
         5.9227e+04,  6.6244e+04,  3.0824e+04,  1.1024e+05,  2.6843e+04,
        -2.0253e+03,  6.2194e+04,  6.2850e+03, -4.8561e+04,  3.5525e+04,
         3.2681e+04, -7.6061e+04,  5.2919e+04,  9.0775e+04,  2.0617e+04,
         2.8052e+03,  6.4431e+04,  1.5134e+04,  4.5953e+04, -1.9247e+04,
         9.7630e+04,  2.7276e+04,  3.0647e+04,  4.2342e+04, -1.1101e+03,
        -2.3795e+04,  1.6483e+05,  4.5880e+04,  6.5606e+04, -2.5157e+04,
        -7.7115e+03,  9.4512e+04, -5.0809e+04,  7.8759e+04,  6.9595e+04,
         2.9231e+03,  4.7988e+04,  3.6668e+03,  5.4307e+04,  3.1461e+04,
        -4.5519e+03,  2.5983e+04,  4.1848e+04, -9.1882e+03,  1.9365e+04,
        -5.7895e+03,  7.9265e+03, 