In [1]:
import os

# Restrict the process to one physical GPU *before* torch is imported, so the
# setting is in place regardless of when the CUDA runtime gets initialised.
# NOTE(review): index "6" is specific to the machine this notebook ran on.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

import torch

# Number of CUDA devices torch can see under the restriction above.
torch.cuda.device_count()

1

In [2]:
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pretrained google/gemma-2-2b checkpoint into a HookedTransformer
# placed on the device chosen above.
model = HookedTransformer.from_pretrained("google/gemma-2-2b", device=device)

# layer = 13

# # get the SAE for this layer
# sae, cfg_dict, _ = SAE.from_pretrained(
#     release = "gemma-scope-2b-pt-att-canonical",
#     sae_id = f"layer_{layer}/width_16k/canonical",
#     device = device
# )

# # get hook point
# hook_point = sae.cfg.hook_name

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 3/3 [00:07<00:00,  2.36s/it]


Loaded pretrained model google/gemma-2-2b into HookedTransformer


In [3]:
import json
import numpy as np
# Load the cached accuracies. Each entry is indexed below as
# entry[0] = accuracy and entry[1] = layer index.
with open("cache/accuracies_sorted_gemma.json", "r") as file:
    accuracies_sorted = json.load(file)

# Bucket the accuracies per layer in a single pass instead of rescanning the
# whole list once per layer.
# 26 is presumably the number of layers of gemma-2-2b -- TODO confirm against
# model.cfg.n_layers.
accuracies_per_layer = {layer: [] for layer in range(26)}
for entry in accuracies_sorted:
    if entry[1] in accuracies_per_layer:
        accuracies_per_layer[entry[1]].append(entry[0])

# Layers without any entry are left out: np.mean([]) evaluates to NaN, and a
# NaN inside the list below makes the sort order (and therefore the layer
# selection that slices this list) unreliable.
accuracies_by_layer = {
    layer: np.mean(layer_accuracies)
    for layer, layer_accuracies in accuracies_per_layer.items()
    if layer_accuracies
}
# [mean accuracy, layer] pairs in ascending order of mean accuracy.
accuracies_by_layer_sorted = sorted([[acc, l] for l, acc in accuracies_by_layer.items()])

k = 1  # how many layers are taken from the front of accuracies_by_layer_sorted
alpha = 10  # not used in the visible cells; presumably a steering strength -- TODO confirm
# with open("cache/steering_directions.json", "r") as file:
#     steering_directions_json = json.load(file)
# steering_directions = {}
# for item in steering_directions_json:
#     steering_directions[eval(item)] = np.array(steering_directions_json[item])
# with open("cache/stds.json", "r") as file:
#     stds_json = json.load(file)
# stds = {}
# for item in stds_json:
#     stds[eval(item)] = stds_json[item]

In [6]:
from tqdm import tqdm
import tasks.truthfulqa
from torch.utils.data import DataLoader
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np
import torch
from sae_lens import SAE

# Per-layer results, filled in the same order as `layers`.
steering_directions = []
stds = []
hook_names = []

task = tasks.truthfulqa.TruthfulQA("probing")
tokenized_dataset = task.get_tokenized_dataset(tokenizer=model.tokenizer, batch_size=2, subset=False,
                                                random_seed=42, subset_len=81, max_length=1000)

# Labels and the dataloader do not depend on the layer, so build them once
# instead of once per loop iteration.
y = np.array([1 if item["label"] == "True" else -1 for item in tokenized_dataset])
# Fail loudly here: with an empty class the mean below is NaN and the
# resulting direction would be silently useless.
if not (y == 1).any() or not (y == -1).any():
    raise ValueError('both label classes ("True" and the rest) must be present in the dataset')
# shuffle=False keeps the rows of X aligned with the labels in y.
dataloader = DataLoader(tokenized_dataset, batch_size=2, shuffle=False)

# NOTE(review): accuracies_by_layer_sorted is sorted in ascending order, so
# this picks the k layers with the LOWEST mean accuracy -- confirm that this
# is intended (and not the k highest).
layers = [l for a, l in accuracies_by_layer_sorted[:k]]

for layer in tqdm(layers):
    # Attention-output SAE for this layer.
    sae, cfg_dict, _ = SAE.from_pretrained(
        release = "gemma-scope-2b-pt-att-canonical",
        sae_id = f"layer_{layer}/width_16k/canonical",
        device = device
    )

    hook_point = sae.cfg.hook_name
    hook_names.append(hook_point)
    sae.eval()
    correct_activations = []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch_tokens = batch["tokens"]
            # Cache only the activation the SAE reads; caching every hook of
            # the model for each batch wastes memory and time.
            _, cache = model.run_with_cache(batch_tokens, prepend_bos=True, names_filter=hook_point)

            feature_acts = sae.encode(cache[hook_point])
            # Keep one feature vector per example, taken at position
            # len_of_input - 1 (presumably the last prompt token -- TODO confirm).
            correct_activations.append(feature_acts[np.arange(batch_tokens.shape[0]), batch["len_of_input"] - 1, :].detach().cpu())
            del cache

    correct_activations_dataset = torch.vstack(correct_activations)
    X = correct_activations_dataset.numpy()
    assert X.shape[0] == y.shape[0], "one activation row is expected per labelled example"
    X_correct = X[y == 1]
    X_incorrect = X[y == -1]
    # Difference of class means in SAE feature space.
    steering_direction = (X_correct.mean(axis=0) - X_incorrect.mean(axis=0))
    direction_norm = np.linalg.norm(steering_direction)
    # Spread of the activations projected onto the unit-length direction.
    # NOTE(review): this is measured in SAE feature space, while a later cell
    # saves the direction after multiplying by W_dec -- confirm that the
    # consumer of `stds` expects this scale.
    std = np.std(np.dot(X, steering_direction) / direction_norm)
    steering_directions.append(steering_direction)
    stds.append(std)

100%|██████████| 409/409 [01:20<00:00,  5.06it/s]
100%|██████████| 1/1 [01:21<00:00, 81.88s/it]


In [47]:
# Index of every entry of the first steering direction: 0 .. n_features - 1.
np.arange(steering_directions[0].shape[0])

array([    0,     1,     2, ..., 16381, 16382, 16383])

In [50]:
# Indices of the entries of the first steering direction whose absolute value exceeds 0.1.
first_direction = steering_directions[0]
np.arange(first_direction.shape[0])[np.abs(first_direction) > 0.1]

array([   36,  1050,  1418,  1777,  2577,  3036,  3360,  3466,  4421,
        4473,  4783,  4956,  5504,  5568,  6027,  6406,  7218,  7464,
        7686,  8702,  9055, 11123, 11691, 12773, 13009, 13235, 13411,
       13831, 14368, 14855, 15948, 15974])

In [25]:
# Map every feature-space direction through the decoder of the SAE that
# belongs to ITS layer. Relying on the `sae` variable left over from the
# extraction loop would decode every direction with the last layer's decoder,
# which is only correct when a single layer was processed.
steering_directions_in_attn_space = []
for layer, direction in zip(layers, steering_directions):
    print(np.count_nonzero(direction), direction.shape)
    layer_sae, _, _ = SAE.from_pretrained(
        release = "gemma-scope-2b-pt-att-canonical",
        sae_id = f"layer_{layer}/width_16k/canonical",
        device = device
    )
    # W_dec is presumably (n_features, d_in); the matmul below needs its
    # leading dimension to match the length of `direction` -- TODO confirm.
    decoder_weights = layer_sae.W_dec.detach().cpu().numpy()
    steering_directions_in_attn_space.append(direction @ decoder_weights)

3194 (16384,)


In [27]:
import json
# Persist the projected directions as {layer (as string): list of floats}.
with open("cache/steering_directions_gemma_sae_full.json", "w") as file:
    directions_by_layer = {}
    for layer_idx, attn_direction in zip(layers, steering_directions_in_attn_space):
        directions_by_layer[str(layer_idx)] = attn_direction.tolist()
    json.dump(directions_by_layer, file)

# Persist the matching standard deviations as {layer (as string): float}.
with open("cache/stds_gemma_sae_full.json", "w") as file:
    stds_by_layer = {}
    for layer_idx, layer_std in zip(layers, stds):
        # .item() converts the numpy scalar to a plain Python float for JSON.
        stds_by_layer[str(layer_idx)] = layer_std.item()
    json.dump(stds_by_layer, file)