In [1]:
import os
import torch

# Restrict this process to the GPU with physical index 7 (machine-specific choice).
# NOTE(review): torch is already imported above; this presumably still takes effect
# because CUDA has not been initialized yet — the recorded output of 1 is consistent with that.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
# Show how many CUDA devices torch can see after the restriction.
torch.cuda.device_count()

1

In [2]:
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Base language model, wrapped by TransformerLens so activations can be hooked.
model = HookedTransformer.from_pretrained("google/gemma-2-2b", device=device)

# Layer whose attention-output SAE is loaded below.
layer = 13

# Pretrained sparse autoencoder for the chosen layer.
sae, cfg_dict, _ = SAE.from_pretrained(
    release="gemma-scope-2b-pt-att-canonical",
    sae_id=f"layer_{layer}/width_16k/canonical",
    device=device,
)

# Name of the model hook the SAE reads from; reused by the cells below.
hook_point = sae.cfg.hook_name
print(hook_point)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 3/3 [01:35<00:00, 31.82s/it]


Loaded pretrained model google/gemma-2-2b into HookedTransformer
blocks.13.attn.hook_z


In [3]:
from datasets import load_dataset 

# Training split of ARC-Easy, fully materialised (not streamed).
dataset = load_dataset("allenai/ai2_arc", "ARC-Easy", split="train", streaming=False)

def format_examples(examples) -> dict[str, list]:
    """Build labelled "question + choice" sentences from a batch of questions.

    For every question, at most two sentences are produced, in the order the
    choices appear: the first choice whose label equals the answer key
    (labelled "True") and the first choice whose label differs from it
    (labelled "False").

    Args:
        examples: Batched mapping with parallel lists under "question",
            "choices" (each a mapping with parallel "text" and "label" lists)
            and "answerKey".

    Returns:
        A dict with two parallel lists: "sent" (question and choice joined by
        a single space) and "label" ("True" or "False" strings).
    """
    examples_formatted = {"sent": [], "label": []}
    for question, choices, answer in zip(examples["question"], examples["choices"], examples["answerKey"]):
        have_correct_example, have_incorrect_example = False, False
        for choice, label in zip(choices["text"], choices["label"]):
            is_answer = label == answer
            if is_answer and not have_correct_example:
                examples_formatted["sent"].append(question + " " + choice)
                examples_formatted["label"].append("True")
                have_correct_example = True
            elif not is_answer and not have_incorrect_example:
                # Only a choice that really differs from the answer key may be
                # used as the negative example; a repeated answer label must
                # never be emitted as "False".
                examples_formatted["sent"].append(question + " " + choice)
                examples_formatted["label"].append("False")
                have_incorrect_example = True
            if have_correct_example and have_incorrect_example:
                # Both examples for this question are collected.
                break
    return examples_formatted

# Replace the raw ARC columns with the labelled sentences built by format_examples.
probing_dataset = dataset.map(
    format_examples,
    batched=True,
    batch_size=8,
    remove_columns=dataset.column_names,
)

print(len(probing_dataset))

# Keep a fixed-seed random subset of 500 sentences and report how many are "True".
probing_dataset = probing_dataset.shuffle(seed=42).select(range(500))
true_examples = probing_dataset.filter(lambda example: example["label"] == "True")
print(len(probing_dataset), len(true_examples))

4502
500 264


In [4]:
import numpy as np
from functools import partial

def tokenize(examples, column_name, tokenizer, max_length):
    """Tokenize a batch of texts and record each row's unpadded length.

    Args:
        examples: Batched mapping; ``examples[column_name]`` is the list of
            texts to tokenize.
        column_name: Key of the text column in ``examples``.
        tokenizer: Callable tokenizer exposing ``padding_side`` and
            ``pad_token_id``. Its ``padding_side`` is set to "right" as a
            side effect.
        max_length: Maximum number of tokens kept per row.

    Returns:
        A dict with "tokens" (2-D array of ids, right-padded to the longest
        row of the batch) and "len_of_input" (1-D array with the number of
        tokens before the first pad id in each row).
    """
    # Right padding is what makes "index of the first pad" equal the length.
    tokenizer.padding_side = "right"
    text = examples[column_name]
    # truncation=True is required for max_length to be enforced; without it
    # over-long inputs are returned at full length.
    tokens = tokenizer(
        text,
        return_tensors="np",
        padding="longest",
        truncation=True,
        max_length=max_length,
    )["input_ids"]
    # argmax over a boolean row gives the first pad position, or 0 when the
    # row holds no pad at all.
    len_of_input = np.argmax(tokens == tokenizer.pad_token_id, axis=1)
    # A 0 must mean "no padding in this row", never "row starts with a pad".
    assert (tokens[len_of_input == 0] != tokenizer.pad_token_id).all(), (len_of_input, tokens)
    len_of_input[len_of_input == 0] = tokens.shape[1]
    return {"tokens": tokens, "len_of_input": len_of_input}

# Bind the fixed arguments once so the map call only supplies the batch.
tokenize_sentences = partial(
    tokenize,
    column_name="sent",
    tokenizer=model.tokenizer,
    max_length=sae.cfg.context_size,
)

tokenized_dataset = probing_dataset.map(
    tokenize_sentences,
    batched=True,
    batch_size=8,
    num_proc=None,
)
# Return the listed columns as torch values when rows are accessed.
tokenized_dataset.set_format(type="torch", columns=["tokens", "label", "len_of_input"])
tokenized_dataset[0]

{'label': 'False',
 'tokens': tensor([     2,  13033,   1134,   1546,   5476,    614,    573,  17930,    576,
            671,   4018,    675,    476,   2301,   6620,    578,   2910,   5182,
         235336,   3178,   5601,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0]),
 'len_of_input': tensor(21)}

In [5]:
# Display the name of the model hook point the SAE is configured for.
sae.cfg.hook_name

'blocks.13.attn.hook_z'

In [6]:
from tqdm import tqdm
from torch.utils.data import DataLoader
from sklearn.linear_model import LogisticRegression
import numpy as np

batch_size = 8
head = 0  # index of the attention head whose activations are probed
correct_activations = []

dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=False)

with torch.no_grad():
    for batch in tqdm(dataloader):
        batch_tokens = batch["tokens"]
        _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)

        # Activations at the SAE's hook point, indexed below as
        # [row, position, head, feature].
        feature_acts = cache[sae.cfg.hook_name]
        row_index = np.arange(batch_tokens.shape[0])
        # Position of the last non-padding token of every row.
        last_token_index = batch["len_of_input"] - 1
        last_token_acts = feature_acts[row_index, last_token_index, head, :]
        correct_activations.append(last_token_acts.detach().cpu())
        # Drop the cache before the next forward pass.
        del cache

correct_activations_dataset = torch.vstack(correct_activations)

# Design matrix and +1 / -1 targets ("True" sentences are the positive class).
X = correct_activations_dataset.numpy()
y = np.array([1 if item["label"] == "True" else -1 for item in tokenized_dataset])

true_mask = y == 1
false_mask = y == -1
X_correct = X[true_mask]
X_incorrect = X[false_mask]
# Difference between the mean activation of "True" and "False" sentences.
X_difference = X_correct.mean(axis=0) - X_incorrect.mean(axis=0)

# L1-regularised logistic-regression probe fitted on the same activations.
lr = LogisticRegression(penalty="l1", solver="liblinear").fit(X, y)

100%|██████████| 63/63 [00:22<00:00,  2.83it/s]


In [7]:
# NOTE(review): this is not the last expression of the cell, so its value is not displayed.
X_difference.shape

# Steering direction: the class-mean difference X_difference as a torch tensor.
steering_vector = torch.tensor(X_difference)

In [9]:
def steering(activations, hook, head, steering_strength, steering_vector, max_act, pos_to_intervene=[None]):
    """Forward hook that shifts one head's activations along a steering vector.

    For every batch row, ``max_act * steering_strength * steering_vector`` is
    added in place to ``activations[row, pos_to_intervene[row], head, :]``.

    Args:
        activations: Tensor indexed as [batch, position, head, feature].
        hook: Hook object supplied by the caller of the hook; unused here.
        head: Index of the head to modify.
        steering_strength: Scalar multiplier of the steering vector.
        steering_vector: Vector added along the last dimension.
        max_act: Additional scalar multiplier.
        pos_to_intervene: One position index per batch row.
            NOTE(review): the default ``[None]`` looks like a placeholder;
            the calls in this notebook always pass an explicit list.

    Returns:
        The same ``activations`` tensor, modified in place.
    """
    batch_rows = np.arange(activations.shape[0])
    shift = max_act * steering_strength * steering_vector
    activations[batch_rows, pos_to_intervene, head, :] += shift
    return activations
from utils.dataset import get_tokenized_arc_easy_for_testing
# Put the steering direction on the same device as the model.
steering_vector = steering_vector.to(model.cfg.device)

# Test split prepared by the project helper; rows provide "tokens",
# "correct_label" and "num_labels".
tokenized_dataset = get_tokenized_arc_easy_for_testing(model, sae.cfg.context_size, "test", batch_size)
dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=False)
correct = 0
sum_logit_diff = 0
head = 0
with torch.no_grad():
    # Iterating the dataloader directly lets tqdm show the total and avoids
    # shadowing the builtin `iter` with an unused loop index.
    for batch in tqdm(dataloader):
        batch_tokens = batch["tokens"]

        steering_hook = partial(
            steering,
            steering_vector=steering_vector,
            steering_strength=100.,
            max_act=1.,
            head=head,
            pos_to_intervene=[-1] * batch_tokens.shape[0],  # padding on the left
        )

        with model.hooks(fwd_hooks=[(sae.cfg.hook_name, steering_hook)]):
            # Only the final position is scored, so slice before copying to
            # the CPU instead of transferring the full [batch, seq, vocab] tensor.
            last_token_logits = model(batch_tokens, return_type="logits")[:, -1, :].cpu()

        logit_diff = []
        for batch_i, (correct_label, num_labels) in enumerate(zip(batch["correct_label"], batch["num_labels"])):
            correct_label, num_labels = int(correct_label), int(num_labels)
            # Answer options are scored through the token ids of the letters "A", "B", ...
            correct_id = model.tokenizer.convert_tokens_to_ids(chr(ord("A") + correct_label))
            incorrect_ids = [
                model.tokenizer.convert_tokens_to_ids(chr(ord("A") + option))
                for option in range(num_labels)
                if option != correct_label
            ]
            # Margin between the correct letter and the strongest wrong letter.
            best_incorrect_logit = last_token_logits[batch_i, incorrect_ids].max()
            logit_diff.append((last_token_logits[batch_i, correct_id] - best_incorrect_logit).item())

        logit_diff = np.array(logit_diff)
        correct += (logit_diff > 0).sum()
        sum_logit_diff += logit_diff.sum()
# (fraction of questions with a positive margin, mean margin)
correct / len(tokenized_dataset), sum_logit_diff / len(tokenized_dataset)

297it [03:26,  1.44it/s]


(0.6813973063973064, 1.5318286764099942)

In [15]:
def steering(activations, hook, head, steering_strength, steering_vector, max_act, pos_to_intervene=[None]):
    """Forward hook that shifts one head's activations along a steering vector.

    For every batch row, ``max_act * steering_strength * steering_vector`` is
    added in place to ``activations[row, pos_to_intervene[row], head, :]``.

    Args:
        activations: Tensor indexed as [batch, position, head, feature].
        hook: Hook object supplied by the caller of the hook; unused here.
        head: Index of the head to modify.
        steering_strength: Scalar multiplier of the steering vector.
        steering_vector: Vector added along the last dimension.
        max_act: Additional scalar multiplier.
        pos_to_intervene: One position index per batch row.
            NOTE(review): the default ``[None]`` looks like a placeholder;
            the calls in this notebook always pass an explicit list.

    Returns:
        The same ``activations`` tensor, modified in place.
    """
    batch_rows = np.arange(activations.shape[0])
    shift = max_act * steering_strength * steering_vector
    activations[batch_rows, pos_to_intervene, head, :] += shift
    return activations
from utils.dataset import get_tokenized_arc_easy_for_testing
# Baseline run: an all-zero vector makes the hook a no-op. It gets its own name
# so the learned `steering_vector` is not overwritten, and its size is taken
# from the model config instead of a hard-coded 256.
baseline_steering_vector = torch.zeros(model.cfg.d_head).to(model.cfg.device)

# Test split prepared by the project helper; rows provide "tokens",
# "correct_label" and "num_labels".
tokenized_dataset = get_tokenized_arc_easy_for_testing(model, sae.cfg.context_size, "test", batch_size)
dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=False)
correct = 0
sum_logit_diff = 0
head = 0
with torch.no_grad():
    # Iterating the dataloader directly lets tqdm show the total and avoids
    # shadowing the builtin `iter` with an unused loop index.
    for batch in tqdm(dataloader):
        batch_tokens = batch["tokens"]

        steering_hook = partial(
            steering,
            steering_vector=baseline_steering_vector,
            steering_strength=1.,
            max_act=1.,
            head=head,
            pos_to_intervene=[-1] * batch_tokens.shape[0],  # padding on the left
        )

        with model.hooks(fwd_hooks=[(sae.cfg.hook_name, steering_hook)]):
            # Only the final position is scored, so slice before copying to
            # the CPU instead of transferring the full [batch, seq, vocab] tensor.
            last_token_logits = model(batch_tokens, return_type="logits")[:, -1, :].cpu()

        logit_diff = []
        for batch_i, (correct_label, num_labels) in enumerate(zip(batch["correct_label"], batch["num_labels"])):
            correct_label, num_labels = int(correct_label), int(num_labels)
            # Answer options are scored through the token ids of the letters "A", "B", ...
            correct_id = model.tokenizer.convert_tokens_to_ids(chr(ord("A") + correct_label))
            incorrect_ids = [
                model.tokenizer.convert_tokens_to_ids(chr(ord("A") + option))
                for option in range(num_labels)
                if option != correct_label
            ]
            # Margin between the correct letter and the strongest wrong letter.
            best_incorrect_logit = last_token_logits[batch_i, incorrect_ids].max()
            logit_diff.append((last_token_logits[batch_i, correct_id] - best_incorrect_logit).item())

        logit_diff = np.array(logit_diff)
        correct += (logit_diff > 0).sum()
        sum_logit_diff += logit_diff.sum()
# (fraction of questions with a positive margin, mean margin)
correct / len(tokenized_dataset), sum_logit_diff / len(tokenized_dataset)

Map: 100%|██████████| 2376/2376 [00:00<00:00, 2714.60 examples/s]
297it [03:23,  1.46it/s]


(0.7066498316498316, 2.1488056580225625)

In [None]:
strength 1: (0.7070707070707071, 2.160504623494967)
strength 0: (0.7066498316498316, 2.1488056580225625)
strength 10: (0.718013468013468, 2.201320888819518)
