# An investigation into the use of SAE features as steering vectors

## Setup

### If in Colab, install deps. Otherwise, set up autoreload.

In [1]:
# Detect the runtime: importing google.colab only succeeds inside Colab, where
# the dependencies are installed; otherwise autoreload is enabled for local work.
try:
    import google.colab

    IN_COLAB = True
    # IPython line magic, so this cell only runs under IPython/Jupyter.
    %pip install sae-lens transformer-lens

except ImportError:
    # Local
    IN_COLAB = False

    import IPython

    ipython = IPython.get_ipython()
    # Load the autoreload extension and set it to mode "2".
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

### Imports

In [2]:
import collections
import functools
import json
import math
import os
import requests
from pprint import pprint
import pathlib

import huggingface_hub
from matplotlib import pyplot as plt
import pandas as pd
from sae_lens import SAE
from tqdm import tqdm, trange
import torch as t
from transformer_lens import HookedTransformer

### Other settings

In [None]:
# Inference only: autograd is switched off globally for the whole notebook.
t.set_grad_enabled(False)

# Device preference order: Apple MPS, then CUDA, then CPU.
if t.backends.mps.is_available():
    DEVICE = "mps"
elif t.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"\nDevice: {DEVICE}")

### HuggingFace Login (for Gemma)

In [None]:
# Interactive Hugging Face login; per the section heading it is needed for Gemma.
huggingface_hub.notebook_login()

## Load GPT-2

In [None]:
# Load GPT-2 small as a HookedTransformer on the selected device; every GPT-2
# experiment below reads this `gpt2_small` global.
gpt2_small = HookedTransformer.from_pretrained(
    "gpt2-small",
    device=DEVICE,
)

pprint(gpt2_small.cfg)

## Import GPT-2 J-B SAEs

In [None]:
# One SAE per transformer layer, each trained on that layer's `hook_resid_pre`.
# The list index therefore equals the layer number, which later cells rely on.
saes = [
    SAE.from_pretrained(
        release="gpt2-small-res-jb",
        sae_id=f"blocks.{layer}.hook_resid_pre",
        device=DEVICE,
    )[0]  # only the first element of the returned value is kept and used as the SAE
    for layer in trange(gpt2_small.cfg.n_layers)
]

pprint(saes[0].cfg)

## Export SAE feature explanations for later search

In [7]:
def load_explanations_from_saes(saes, save_path):
    """Return the Neuronpedia explanations for ``saes``, cached at ``save_path``.

    If ``save_path`` already exists it is parsed as JSON and returned without
    touching the network. Otherwise one export request is made per SAE and the
    combined list is written to ``save_path``.

    The cache is written only after *every* download has succeeded, so an
    interrupted or failed run cannot leave a partial file that a later run
    would load as if it were complete.

    Args:
        saes: Iterable of SAEs whose ``cfg.neuronpedia_id`` has the form
            ``"<model>/<sae id>"``.
        save_path: Path of the JSON cache file.

    Returns:
        Tuple ``(explanations, save_path)``.

    Raises:
        requests.HTTPError: If any export request returns an error status.
    """
    try:
        with open(save_path, "r", encoding="utf-8") as f:
            return json.load(f), save_path
    except FileNotFoundError:
        pass  # cache miss: fall through to the download below

    url = "https://www.neuronpedia.org/api/explanation/export"
    # The API key is read from the environment; the header is identical per request.
    headers = {"X-Api-Key": os.getenv("NEURONPEDIA_TOKEN")}

    explanations = []
    for sae in tqdm(saes):
        model, sae_id = sae.cfg.neuronpedia_id.split("/")
        querystring = {"modelId": model, "saeId": sae_id}

        response = requests.get(url, headers=headers, params=querystring)
        # Fail loudly rather than merging an error payload into the results.
        response.raise_for_status()
        explanations.extend(response.json())

    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(explanations, f, indent=2)

    return explanations, save_path


# Relative cache path; when this file already exists it is loaded instead of
# downloading, so delete it to force a fresh export.
explanations_fpath = "gpt2-small_res-jb_explanations.json"
explanations, _ = load_explanations_from_saes(saes, explanations_fpath)

## Find all features whose explanations contain keywords

In [None]:
def get_explanations_with_keywords_by_layer(explanations, keywords):
    """Group the explanations whose description mentions any keyword, by layer.

    Matching is a case-insensitive substring test, so ``"Australia"`` and
    ``"AUSTRALIA"`` select the same features.

    Args:
        explanations: Iterable of dicts with a ``"description"`` string (or
            ``None``) and a ``"layer"`` string that starts with the layer
            number followed by ``"-"``.
        keywords: A keyword string or an iterable of keyword strings.

    Returns:
        ``collections.defaultdict(list)`` mapping layer number (int) to the
        matching explanation dicts, in their original order. Layers with no
        match are absent until they are first indexed.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(keywords, str):
        keywords = [keywords]
    # Descriptions are upper-cased below, so the keywords must be as well.
    keywords_upper = [keyword.upper() for keyword in keywords]

    explanations_filtered = collections.defaultdict(list)

    for explanation in explanations:
        # Treat a missing (None) description as empty instead of crashing.
        description = (explanation["description"] or "").upper()
        if any(keyword in description for keyword in keywords_upper):
            layer = int(explanation["layer"].split("-")[0])
            explanations_filtered[layer].append(explanation)

    return explanations_filtered


# Keywords are matched against upper-cased feature descriptions.
keywords = ["AUSTRALIA"]

explanations_filtered = get_explanations_with_keywords_by_layer(explanations, keywords)

# Per-layer tally of matching features, printed with each feature's description.
explanation_count = 0
explanation_counts_by_layer = {}
for layer in range(len(saes)):
    explanation_count_in_layer = len(explanations_filtered[layer])
    explanation_counts_by_layer[layer] = explanation_count_in_layer
    explanation_count += explanation_count_in_layer
    print(
        f"\nNumber of relevant features in layer {layer}: {explanation_count_in_layer}"
    )
    explanations_filtered_layer_str = "\n\t".join(
        [
            f"{explanation['index']}:\t{explanation['description']}"
            for explanation in explanations_filtered[layer]
        ]
    )
    print(f"\t{explanations_filtered_layer_str}")

# The "%" presentation type multiplies by 100; the previous output showed the
# raw fraction with a "%" sign appended. An empty export no longer divides by zero.
explanation_share = explanation_count / len(explanations) if explanations else 0.0
print(
    f"Total relevant features: {explanation_count} ({explanation_share:.4%} of total)"
)


In [None]:
# Bar chart: number of keyword-matching features found in each layer's SAE.
fig = plt.figure(figsize=(10, 5))

layer_numbers = list(explanation_counts_by_layer)
plt.bar(
    layer_numbers,
    explanation_counts_by_layer.values(),
    color="blue",
    width=0.4,
)

# One tick per layer so every bar is labelled.
plt.xticks(layer_numbers)
plt.xlabel("SAE/Layer #")
plt.ylabel("No. of features containing 'Australia'")
plt.title("No. of features containing 'Australia' by SAE/Layer")
plt.show()

## Find SAE feature indices that correlate with the intended steering direction

In [None]:
def run_model_and_get_filtered_activations(
    model, prompt, explanations_filtered, activation_threshold, quiet=True
):
    """Run ``prompt`` and report which keyword-matched SAE features fire on it.

    For every layer that has at least one candidate explanation, the cached
    activation at that SAE's hook is encoded, and the candidate features whose
    activation is ``>= activation_threshold`` are recorded.

    NOTE: the SAEs come from the notebook-level ``saes`` list and tensors are
    created on the notebook-level ``DEVICE``; neither is a parameter.

    Args:
        model: Model exposing ``run_with_cache``, ``to_tokens`` and
            ``to_str_tokens``.
        prompt: Text to run through the model (a BOS token is prepended).
        explanations_filtered: Mapping from layer number to candidate
            explanation dicts (each with ``"index"`` and ``"description"``).
        activation_threshold: Minimum activation for a feature to be kept.
        quiet: When false, print the prompt's tokens and token strings.

    Returns:
        Tuple ``(cache, acts_filtered, saes_out)``. ``acts_filtered[layer]``
        holds parallel lists ``"val"``, ``"idx"`` and ``"desc"``. The lists are
        flattened across batch and token position, so entries are grouped by
        position first and sorted by descending activation only *within* each
        position. ``saes_out[layer]`` is the SAE reconstruction for that layer.
    """
    _, cache = model.run_with_cache(prompt, prepend_bos=True)

    if not quiet:
        tokens = model.to_tokens(prompt)
        print(f"Tokens: {tokens}")
        print(f"Token strings: {model.to_str_tokens(tokens)}")

    saes_out = {}
    acts_filtered = {}

    for layer, sae in enumerate(tqdm(saes)):
        layer_explanations = explanations_filtered[layer]
        if not layer_explanations:
            continue

        # Build the index -> description table once per layer instead of
        # scanning the explanation list for every kept feature. setdefault
        # keeps the first description when an index appears more than once.
        description_by_index = {}
        for expl in layer_explanations:
            description_by_index.setdefault(int(expl["index"]), expl["description"])

        candidate_idx = t.tensor(
            [int(expl["index"]) for expl in layer_explanations],
            device=DEVICE,
        )

        feature_acts = sae.encode(
            cache[sae.cfg.hook_name]
        )  # shape (batch, sequence, features)

        saes_out[layer] = sae.decode(feature_acts)

        # Sort only the candidate features, per token position, strongest first.
        vals_sorted, sort_order = feature_acts[:, :, candidate_idx].sort(
            descending=True, dim=-1
        )
        # Map positions within the candidate subset back to SAE feature indices.
        idx_sorted = candidate_idx[sort_order]

        mask = vals_sorted >= activation_threshold
        kept_idx = idx_sorted[mask].tolist()

        acts_filtered[layer] = {
            "val": vals_sorted[mask].tolist(),
            "idx": kept_idx,
            "desc": [description_by_index[feature_idx] for feature_idx in kept_idx],
        }

    return (
        cache,
        acts_filtered,
        saes_out,
    )


# Prompt used to find which keyword-matched features actually activate; only
# activations of at least `activation_threshold` are kept.
sv_prompt = "Sydney Opera House"
activation_threshold = 1.0

cache, acts_filtered, saes_out = run_model_and_get_filtered_activations(
    gpt2_small, sv_prompt, explanations_filtered, activation_threshold, quiet=True
)

# print_relevant_features(act_vals, act_idx, gpt2_small)

print("\nFiltered activations:")
pprint(acts_filtered)

In [None]:
# Pick one steering vector per layer: the decoder row of the first feature that
# passed the activation threshold in that layer.
steering_vectors_positions_by_layer = {}
for layer, acts_filt_at_layer in acts_filtered.items():
    # Layers where no candidate feature reached the threshold get no vector.
    if len(acts_filt_at_layer["idx"]) == 0:
        continue

    # Index of the last token of the steering prompt (sequence length - 1).
    position = saes_out[layer].shape[1] - 1
    # NOTE(review): the "idx"/"val" lists are flattened over token positions, so
    # element 0 is the strongest feature at the earliest position that has any
    # feature above threshold - not necessarily the strongest overall. Confirm
    # this is the intended choice.
    idx = acts_filt_at_layer["idx"][0]
    val = acts_filt_at_layer["val"][0]
    steering_vector = saes[layer].W_dec[idx]
    explanation = next(
        expl["description"]
        for expl in explanations_filtered[layer]
        if int(expl["index"]) == idx
    )

    steering_vectors_positions_by_layer[layer] = {
        "idx": idx,
        "pos": position,
        "sv": steering_vector,
        "val": val,
        "expl": explanation,
    }
# Tensor-free copy of the selection; it is not read again in this cell.
svps_by_layer_to_print = {
    layer: {
        "idx": svp["idx"],
        "val": svp["val"],
        "expl": svp["expl"],
    }
    for layer, svp in steering_vectors_positions_by_layer.items()
}
print(f"Total number of steering vectors: {len(steering_vectors_positions_by_layer)}\n")
print(f"Steering vector layers: {[k for k in steering_vectors_positions_by_layer]}\n")

for layer, svp in steering_vectors_positions_by_layer.items():
    print(
        f"Layer {layer}:\n\tIndex:\t\t{svp['idx']}\n\tValue:\t\t{svp['val']}\n\tExplanation:\t{svp['expl']}\n"
    )

In [12]:
def steering_hook_all_layers(
    resid_pre, hook, steering_on, steering_vector, coeff, position
):
    """Forward hook that adds a scaled steering vector to the residual stream.

    When ``steering_on`` is true, ``coeff * steering_vector`` is added in place
    to sequence positions ``[0, position - 1)`` of ``resid_pre``; otherwise the
    tensor is left untouched. ``hook`` is accepted but not used. Nothing is
    returned; the edit happens through the in-place update.
    """
    if not steering_on:
        return

    # Slicing clamps to the tensor length, so a shorter sequence is steered in
    # full rather than raising.
    steer_until = position - 1
    resid_pre[:, :steer_until, :] += coeff * steering_vector


def hooked_generate(
    prompt_batch, model, fwd_hooks=None, max_new_tokens=20, seed=None, **kwargs
):
    """Sample continuations for ``prompt_batch`` with forward hooks attached.

    Args:
        prompt_batch: Prompts to tokenize with ``model.to_tokens``.
        model: Model exposing ``hooks``, ``to_tokens`` and ``generate``.
        fwd_hooks: List of ``(hook name, hook function)`` pairs active during
            generation. ``None`` (the default) means no hooks.
        max_new_tokens: Number of tokens to generate per prompt.
        seed: If not ``None``, seeds the torch RNG before sampling.
        **kwargs: Extra sampling options forwarded to ``model.generate``.

    Returns:
        Whatever ``model.generate`` returns for the tokenized prompts.
    """
    # A fresh list per call: a mutable `[]` default would be shared between calls.
    if fwd_hooks is None:
        fwd_hooks = []

    if seed is not None:
        t.manual_seed(seed)

    with model.hooks(fwd_hooks=fwd_hooks):
        tokenized = model.to_tokens(prompt_batch)
        result = model.generate(
            stop_at_eos=False,  # avoids a bug on MPS
            input=tokenized,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            **kwargs,
        )
    return result


def generate_multi_layer(
    prompt: str,
    model: HookedTransformer,
    steering_vectors_positions_by_layer: dict,
    coeff: float,
    sampling_kwargs,
    num_responses: int = 3,
    steering_on: bool = True,
    max_new_tokens: int = 20,
):
    """Generate ``num_responses`` samples for ``prompt`` with one hook per layer.

    Each entry of ``steering_vectors_positions_by_layer`` (layer number mapped
    to a dict with ``"sv"`` and ``"pos"``) becomes a ``steering_hook_all_layers``
    hook on ``blocks.{layer}.hook_resid_post``. Existing hooks on ``model`` are
    reset first. Sampling is unseeded.

    NOTE(review): the SAEs in this notebook are loaded for ``hook_resid_pre``
    while the hook is attached to ``hook_resid_post`` of the same layer -
    confirm that this offset is intended.

    Returns:
        The result of ``hooked_generate`` for the repeated prompt.
    """
    model.reset_hooks()

    editing_hooks = [
        (
            f"blocks.{layer}.hook_resid_post",
            functools.partial(
                steering_hook_all_layers,
                steering_on=steering_on,
                steering_vector=svp["sv"],
                coeff=coeff,
                position=svp["pos"],
            ),
        )
        for layer, svp in steering_vectors_positions_by_layer.items()
    ]

    # The same prompt is repeated so all responses are sampled in one batch.
    prompt_batch = [prompt] * num_responses
    return hooked_generate(
        prompt_batch,
        model,
        editing_hooks,
        max_new_tokens,
        seed=None,
        **sampling_kwargs,
    )


def print_res_str(res, model):
    """Print every generated sequence as text, each followed by a rule line.

    The first token of each row is skipped before decoding.
    """
    decoded = model.to_string(res[:, 1:])
    rule = "-" * 80
    for row in range(res.shape[0]):
        print(f"{decoded[row]}\n" + rule)


def print_res_str_tokens(res, model):
    """Print every generated sequence as a list of token strings.

    The first token of each row is skipped; a rule line follows each sequence.
    """
    rule = "-" * 80
    for row in range(res.shape[0]):
        str_tokens = model.to_str_tokens(res[row, 1:])
        print(f"{str_tokens}\n" + rule)

In [None]:
# Side-by-side demo on one prompt: three unsteered samples, then three samples
# steered with the layer-2 vector only.
prompt = "The White House is located in a country called"
coeff = 50

# Shared sampling settings; later cells and the evaluation helper read this global.
sampling_kwargs = {"temperature": 1.0, "top_p": 0.1, "freq_penalty": 1.0}

# Baseline: the hooks are attached but disabled via steering_on=False.
res = generate_multi_layer(
    prompt=prompt,
    model=gpt2_small,
    steering_vectors_positions_by_layer=steering_vectors_positions_by_layer,
    steering_on=False,
    coeff=coeff,
    sampling_kwargs=sampling_kwargs,
    num_responses=3,
)

print("UNSTEERED:")
print_res_str(res, gpt2_small)

# NOTE(review): indexing [2] raises KeyError if no feature was selected for
# layer 2 with the current prompt/threshold.
res = generate_multi_layer(
    prompt=prompt,
    model=gpt2_small,
    steering_vectors_positions_by_layer={2: steering_vectors_positions_by_layer[2]},
    steering_on=True,
    coeff=coeff,
    sampling_kwargs=sampling_kwargs,
    num_responses=3,
)

print("STEERED:")
print_res_str(res, gpt2_small)

In [None]:
# fmt: off
landmark_country_pairs = [
    ["The Eiffel Tower"         , "France"], # The unsteered model answers 'Switzerland' approx. 50% of the time
    ["The Taj Mahal"            , "Tajikistan"], # Actually, India I think, but model baseline is this
    ["The Statue of Liberty"    , "The United States of America"],
    ["The Colosseum"            , "Italy"],
    ["The Great Pyramid of Giza", "Egypt"],
    ["Stonehenge"               , "The United Kingdom"], # Gives better results than England
    ["Petra"                    , "Jordan"],
    ["Macchu Picchu"            , "Peru"],
    ["Burj Khalifa"             , "Qatar"], # Actually The United Arab Emirates
    ["The Kremlin"              , "Russia"],
]
# fmt: on
for i, (landmark, country) in enumerate(landmark_country_pairs):
    prompt = f"{landmark} is located in a country called"
    res = generate_multi_layer(
        prompt=prompt,
        model=gpt2_small,
        steering_vectors_positions_by_layer=steering_vectors_positions_by_layer,
        steering_on=False,
        coeff=50,
        num_responses=10,
        sampling_kwargs=sampling_kwargs,
        max_new_tokens=10,
    )
    print_res_str(res, gpt2_small)

In [15]:
def test_steering_of_landmark_location_in_model(
    landmark_country_pairs,
    model,
    steering_on,
    steering_vectors_positions_by_layer,
    steered_answer,
    coeff,
    quiet=False,
    strict=True,
    max_new_tokens=10,
    num_responses=100,
):
    """Count how often completions name the expected country or the steered one.

    For each ``(landmark, country)`` pair, ``num_responses`` completions of
    ``"<landmark> is located in a country called"`` are sampled and compared,
    case-insensitively, against ``country`` and ``steered_answer``.

    NOTE: sampling settings are read from the notebook-level ``sampling_kwargs``.

    Args:
        landmark_country_pairs: Sequence of ``(landmark, expected country)``.
        model: Model used both for generation and for decoding the results.
        steering_on: Whether the steering hooks modify the residual stream.
        steering_vectors_positions_by_layer: Layer number -> steering info,
            passed through to ``generate_multi_layer``.
        steered_answer: The answer the steering is meant to produce.
        coeff: Steering coefficient.
        quiet: When false, print per-landmark counts.
        strict: When true the completion must *start with* the answer;
            otherwise the answer may appear anywhere in the completion.
        max_new_tokens: Number of tokens generated per completion.
        num_responses: Completions sampled per landmark (default 100, so the
            counts can also be read as percentages).

    Returns:
        Tuple ``(correct_counts, steering_counts)`` of int tensors with one
        entry per landmark.
    """
    num_pairs = len(landmark_country_pairs)
    correct_counts = t.zeros(num_pairs, dtype=t.int)
    steering_counts = t.zeros(num_pairs, dtype=t.int)
    steered_upper = steered_answer.upper()

    for i, (landmark, country) in enumerate(tqdm(landmark_country_pairs)):
        prompt = f"{landmark} is located in a country called"
        res = generate_multi_layer(
            prompt=prompt,
            model=model,
            steering_vectors_positions_by_layer=steering_vectors_positions_by_layer,
            steering_on=steering_on,
            coeff=coeff,
            num_responses=num_responses,
            sampling_kwargs=sampling_kwargs,
            max_new_tokens=max_new_tokens,
        )
        # Decode with the model that generated the tokens (previously the
        # global `gpt2_small`, regardless of the `model` argument).
        res_strs = model.to_string(res[:, 1:])
        country_upper = country.upper()

        for res_str in res_strs:
            # Drop the prompt plus the single character that follows it.
            check_str = res_str[len(prompt) + 1 :].upper()
            if strict:
                is_correct = check_str.startswith(country_upper)
                is_steered = check_str.startswith(steered_upper)
            else:
                is_correct = country_upper in check_str
                is_steered = steered_upper in check_str
            correct_counts[i] += int(is_correct)
            steering_counts[i] += int(is_steered)

        if not quiet:
            print(
                f"{'STEERED' if steering_on else 'UNSTEERED'} test for landmark '{landmark}'"
            )
            print(
                f"Correct Answer: '{country}'\tSteered Answer: '{steered_answer}'"
            )
            print(f"\tNumber of correct responses: {correct_counts[i]}")
            print(f"\tNumber of steered responses: {steering_counts[i]}")

    return correct_counts, steering_counts


## Unsteered results:

In [None]:
# Baseline run: hooks disabled, strict (prefix) matching.
correct_counts, steering_counts = test_steering_of_landmark_location_in_model(
    model=gpt2_small,
    landmark_country_pairs=landmark_country_pairs,
    steering_on=False,
    steering_vectors_positions_by_layer=steering_vectors_positions_by_layer,
    steered_answer="Australia",
    coeff=coeff,
    quiet=True,
    strict=True,
)

In [None]:
def print_steering_results(landmark_country_pairs, correct_counts, steering_counts):
    """Print a table of per-landmark results.

    Args:
        landmark_country_pairs: Sequence of ``(landmark, expected answer)``.
        correct_counts: Per-landmark count of expected answers.
        steering_counts: Per-landmark count of steered answers.

    The two count columns are labelled "Percent ..."; the values are printed
    as given, so they are percentages only when 100 responses were sampled
    per landmark.
    """
    df = pd.DataFrame(
        {
            "Landmarks": [landmark for landmark, _ in landmark_country_pairs],
            "Expected Answers": [answer for _, answer in landmark_country_pairs],
            "Percent Correct": correct_counts,
            "Percent Steered": steering_counts,
        }
    )
    # The previous `df.style.format_index(...)` call built a Styler and threw it
    # away, so it never affected the printed table; it has been removed.
    pd.set_option("display.width", 200)

    print(df.to_string(index=False))


# Tabulate the counts currently held in correct_counts / steering_counts.
print_steering_results(landmark_country_pairs, correct_counts, steering_counts)

## Steered

In [None]:
# Steered run using only the layer-2 vector; `strict` is left at its default.
correct_counts, steering_counts = test_steering_of_landmark_location_in_model(
    model=gpt2_small,
    landmark_country_pairs=landmark_country_pairs,
    steering_on=True,
    steering_vectors_positions_by_layer={2: steering_vectors_positions_by_layer[2]},
    steered_answer="Australia",
    coeff=50,
    quiet=True,
)

print_steering_results(landmark_country_pairs, correct_counts, steering_counts)

## Not strict

### 10 tokens

In [None]:
# Same layer-2 steering, but an answer counts if it appears anywhere within the
# 10 generated tokens (strict=False).
correct_counts, steering_counts = test_steering_of_landmark_location_in_model(
    model=gpt2_small,
    landmark_country_pairs=landmark_country_pairs,
    steering_on=True,
    steering_vectors_positions_by_layer={2: steering_vectors_positions_by_layer[2]},
    steered_answer="Australia",
    coeff=50,
    quiet=True,
    strict=False,
    max_new_tokens=10,
)

print_steering_results(landmark_country_pairs, correct_counts, steering_counts)

### 50 tokens

In [None]:
# As above with a longer window: the answer may appear anywhere within 50
# generated tokens.
correct_counts, steering_counts = test_steering_of_landmark_location_in_model(
    model=gpt2_small,
    landmark_country_pairs=landmark_country_pairs,
    steering_on=True,
    steering_vectors_positions_by_layer={2: steering_vectors_positions_by_layer[2]},
    steered_answer="Australia",
    coeff=50,
    quiet=True,
    strict=False,
    max_new_tokens=50,
)

print_steering_results(landmark_country_pairs, correct_counts, steering_counts)

## Try a range of coeffs and steer using one feature at a time and compare

In [None]:
# Steering coefficients to sweep; 0.0 acts as the unsteered baseline.
coeffs = [
    0.0, 0.1, 0.2, 0.5,
    1.0, 2.0, 5.0,
    10.0, 20.0, 50.0,
    100.0, 500.0, 1000.0, 5000.0,
]

# Result tensors indexed as (steering vector, coefficient, landmark).
correct_counts_all_strict = t.zeros(
    (
        len(steering_vectors_positions_by_layer),
        len(coeffs),
        len(landmark_country_pairs),
    ),
    dtype=t.int,
)
steering_counts_all_strict = correct_counts_all_strict.clone().detach()

# Steer with one layer's vector at a time, strict (prefix) matching.
for i, (layer, svp) in enumerate(tqdm(steering_vectors_positions_by_layer.items())):
    for c, coeff in enumerate(coeffs):
        # A zero coefficient would be a no-op, so the hooks are switched off.
        steering_on = not math.isclose(coeff, 0.0)

        correct_for_coeff, steered_for_coeff = (
            test_steering_of_landmark_location_in_model(
                model=gpt2_small,
                landmark_country_pairs=landmark_country_pairs,
                steering_on=steering_on,
                steering_vectors_positions_by_layer={layer: svp},
                steered_answer="Australia",
                coeff=coeff,
                quiet=True,
            )
        )
        correct_counts_all_strict[i][c] = correct_for_coeff
        steering_counts_all_strict[i][c] = steered_for_coeff


In [None]:
# Dump the strict sweep: one section per steering vector, in dict order, which
# matches the first axis of the result tensors.
print(f"Coeffs: {coeffs}")
for i, (key, svp_info) in enumerate(steering_vectors_positions_by_layer.items()):
    print(f"\nLayer {key}")
    print(f"Feature Index:\t\t{svp_info['idx']}")
    print(f"Feature Act Val:\t{svp_info['val']}")
    print(f"Feature Explanation:\t{svp_info['expl']}")
    print("Correct counts:")
    pprint(correct_counts_all_strict[i])
    print("Steering counts:")
    pprint(steering_counts_all_strict[i])

In [None]:
import numpy as np
import matplotlib.pyplot as plt


def create_steering_plot(counts, steering_coeffs, suptitle, subplot_titles):
    """Plot counts against steering coefficient, one subplot per steering vector.

    The grid has three columns and as many rows as needed, so any number of
    steering vectors is supported (six vectors give the original 2x3 layout).
    Unused grid cells are hidden.

    NOTE: line labels come from the notebook-level ``landmark_country_pairs``.
    NOTE(review): the x axis is logarithmic, so a coefficient of 0 has no
    position on it - confirm how that point should be shown.

    Args:
        counts: Array indexed as ``(steering vector, coefficient, landmark)``.
        steering_coeffs: Coefficients, one per entry of the second axis.
        suptitle: Figure title.
        subplot_titles: One title per steering vector.
    """
    num_plots = counts.shape[0]
    num_pairs = counts.shape[2]

    n_cols = 3
    n_rows = max(1, math.ceil(num_plots / n_cols))
    # squeeze=False keeps `axs` two-dimensional even for a single row.
    fig, axs = plt.subplots(
        n_rows, n_cols, figsize=(18, 5 * n_rows), squeeze=False
    )
    fig.suptitle(suptitle, fontsize=16)
    axs_flat = axs.flatten()

    for i, ax in enumerate(axs_flat):
        if i >= num_plots:
            # Spare grid cell: no steering vector to show here.
            ax.set_visible(False)
            continue

        counts_layer = counts[i]
        for j in range(num_pairs):
            ax.plot(
                steering_coeffs, counts_layer[:, j], label=landmark_country_pairs[j][1]
            )

        # Axis styling is applied once per subplot rather than once per line.
        ax.set_xscale("log")
        ax.set_xlabel("Steering Coefficients")
        ax.set_ylabel("Values")
        ax.set_ylim([0, 105])
        ax.set_title(subplot_titles[i])

    # Every populated subplot carries the same labels; the first one is always
    # populated when there is anything to plot.
    handles, labels = axs_flat[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="center right", bbox_to_anchor=(0.98, 0.5), ncol=1)
    # fig.legend(axs_flat, labels=, loc="right", bbox_to_anchor=(0.91, 0.5))

    plt.tight_layout(rect=[0, 0, 0.82, 0.98])
    plt.subplots_adjust(wspace=0.3, hspace=0.3)
    plt.show()


# Strict sweep, expected-answer counts: one subplot per steering vector.
create_steering_plot(
    counts=correct_counts_all_strict,
    steering_coeffs=coeffs,
    suptitle="Steering Results (STRICT) - 'Correct' Answer Count with Increasing Steering Coefficient",
    subplot_titles=[
        f"Layer {k}, Feature Idx {svp['idx']}"
        for k, svp in steering_vectors_positions_by_layer.items()
    ],
)

In [None]:
# Strict sweep, steered-answer counts: one subplot per steering vector.
create_steering_plot(
    counts=steering_counts_all_strict,
    steering_coeffs=coeffs,
    suptitle="Steering Results (STRICT)  - 'Steered' Answer Count with Increasing Steering Coefficient",
    subplot_titles=[
        f"Layer {k}, Feature Idx {svp['idx']}"
        for k, svp in steering_vectors_positions_by_layer.items()
    ],
)

## Not strict, 10 tokens (coefficient sweep)

In [None]:
# Result tensors indexed as (steering vector, coefficient, landmark).
correct_counts_all_unstrict = t.zeros(
    (
        len(steering_vectors_positions_by_layer),
        len(coeffs),
        len(landmark_country_pairs),
    ),
    dtype=t.int,
)
steering_counts_all_unstrict = correct_counts_all_unstrict.clone().detach()

# Same sweep as the strict one, but an answer counts if it appears anywhere
# within the 10 generated tokens.
for i, (layer, svp) in enumerate(tqdm(steering_vectors_positions_by_layer.items())):
    for c, coeff in enumerate(coeffs):
        # A zero coefficient would be a no-op, so the hooks are switched off.
        steering_on = not math.isclose(coeff, 0.0)

        correct_for_coeff, steered_for_coeff = (
            test_steering_of_landmark_location_in_model(
                model=gpt2_small,
                landmark_country_pairs=landmark_country_pairs,
                steering_on=steering_on,
                steering_vectors_positions_by_layer={layer: svp},
                steered_answer="Australia",
                coeff=coeff,
                quiet=True,
                strict=False,
                max_new_tokens=10,
            )
        )
        correct_counts_all_unstrict[i][c] = correct_for_coeff
        steering_counts_all_unstrict[i][c] = steered_for_coeff

In [None]:
# Dump the non-strict sweep: one section per steering vector, in dict order,
# which matches the first axis of the result tensors.
print(f"Coeffs: {coeffs}")
for i, (key, svp_info) in enumerate(steering_vectors_positions_by_layer.items()):
    print(f"\nLayer {key}")
    print(f"Feature Index:\t\t{svp_info['idx']}")
    print(f"Feature Act Val:\t{svp_info['val']}")
    print(f"Feature Explanation:\t{svp_info['expl']}")
    print("Correct counts:")
    pprint(correct_counts_all_unstrict[i])
    print("Steering counts:")
    pprint(steering_counts_all_unstrict[i])

In [None]:
# Non-strict sweep, expected-answer counts: one subplot per steering vector.
create_steering_plot(
    counts=correct_counts_all_unstrict,
    steering_coeffs=coeffs,
    suptitle="Steering Results (NOT STRICT) - 'Correct' Answer Count with Increasing Steering Coefficient",
    subplot_titles=[
        f"Layer {k}, Feature Idx {svp['idx']}"
        for k, svp in steering_vectors_positions_by_layer.items()
    ],
)

In [None]:
# Non-strict sweep, steered-answer counts: one subplot per steering vector.
create_steering_plot(
    counts=steering_counts_all_unstrict,
    steering_coeffs=coeffs,
    suptitle="Steering Results (NOT STRICT) - 'Steered' Answer Count with Increasing Steering Coefficient",
    subplot_titles=[
        f"Layer {k}, Feature Idx {svp['idx']}"
        for k, svp in steering_vectors_positions_by_layer.items()
    ],
)

In [None]:
# Strict steering results by landmark.
for k in range(steering_counts_all_strict.shape[2]):
    print(landmark_country_pairs[k])
    print(coeffs)
    print(steering_counts_all_strict[:, :, k])  # shape (steering vectors, coeffs)

In [None]:
# Strict steering results summed over landmarks.
steering_counts_summed_over_landmarks_strict = steering_counts_all_strict.sum(
    dim=2
)  # shape (steering vectors, coeffs)
steering_counts_summed_over_landmarks_coeffs_strict = (
    steering_counts_summed_over_landmarks_strict.sum(dim=1)
)

print(steering_counts_summed_over_landmarks_strict)
print(steering_counts_summed_over_landmarks_coeffs_strict)

layers = list(steering_vectors_positions_by_layer.keys())


def create_sum_plot(summed_counts, steering_coeffs, labels, suptitle):
    """Draw one line per label: summed counts against steering coefficient.

    Row ``i`` of ``summed_counts`` is plotted for ``labels[i]``. The x axis is
    logarithmic and fixed to the range 1 to 5000, so smaller coefficients fall
    outside the visible range. The legend is placed to the right of the axes.
    """
    fig, ax = plt.subplots(1, 1)
    fig.suptitle(suptitle, fontsize=16)

    for row, series_label in enumerate(labels):
        ax.plot(steering_coeffs, summed_counts[row, :], label=series_label)
        ax.set_xscale("log")
        ax.set_xlabel("Steering Coefficients")
        ax.set_xlim([1, 5000])
        ax.set_ylabel("Count")

    legend_handles, legend_labels = ax.get_legend_handles_labels()
    fig.legend(
        legend_handles,
        legend_labels,
        loc="center right",
        bbox_to_anchor=(1.2, 0.5),
        ncol=1,
    )

    plt.tight_layout(rect=[-0.2, 0, 0.8, 0.98])
    plt.show()


def create_bar_plot(vals, labels, title, xlabel, ylabel, rotation=0):
    """Draw a bar chart with one bar per label.

    Bars sit at positions ``0 .. len(labels) - 1`` and are ticked with
    ``labels``; ``rotation`` is the tick-label rotation.
    """
    fig, ax = plt.subplots(1, 1)

    tick_positions = np.arange(len(labels))
    ax.bar(tick_positions, vals, align="center")
    ax.set_xticks(np.arange(len(labels)), labels, rotation=rotation)

    ax.set_title(title, fontsize=16, pad=20)
    ax.set_xlabel(xlabel, labelpad=20)
    ax.set_ylabel(ylabel)

    plt.show()


# Line plot: steered-answer totals (summed over landmarks) per steering vector.
create_sum_plot(
    summed_counts=steering_counts_summed_over_landmarks_strict,
    steering_coeffs=coeffs,
    labels=[
        f"Layer {layer}, Feature Idx {svp['idx']}"
        for layer, svp in steering_vectors_positions_by_layer.items()
    ],
    suptitle="Steering Results Per Layer (STRICT):\n'Steered' Answer Counts Summed Over All Prompts",
)

# Bar chart: the same totals additionally summed over all coefficients.
create_bar_plot(
    steering_counts_summed_over_landmarks_coeffs_strict,
    layers,
    "Steering Results (STRICT):\n'Steered' Answer Counts Summed Over All Prompts and Coeffs",
    "Steering Vector (by layer #)",
    "Count",
)

In [None]:
landmarks = [pair[0] for pair in landmark_country_pairs]
print(landmarks)

# Strict steering results summed across all layers.
steering_counts_summed_over_layers_strict = steering_counts_all_strict.sum(
    dim=0
)  # shape (coeffs, landmarks)
print(steering_counts_summed_over_layers_strict)

# Strict steering results summed across all layers and coeffs.
steering_counts_summed_over_layers_coeffs_strict = (
    steering_counts_summed_over_layers_strict.sum(dim=0)
)  # shape (landmarks)
print(steering_counts_summed_over_layers_coeffs_strict)

# Transposed so each row is one landmark. The lines are therefore per landmark,
# and the title now says so (it previously read "Per Layer").
create_sum_plot(
    summed_counts=steering_counts_summed_over_layers_strict.T,
    steering_coeffs=coeffs,
    labels=landmarks,
    suptitle="Steering Results Per Landmark (STRICT):\n'Steered' Answer Counts Summed Over All Layers",
)

create_bar_plot(
    steering_counts_summed_over_layers_coeffs_strict,
    landmarks,
    "Steering Results (STRICT):\n'Steered' Answer Counts Summed Over All Layers, and Coeffs",
    "Landmark",
    "Count",
    rotation=90,
)

# Switch to Gemma 2

In [29]:
# gemma_2_2b = HookedTransformer.from_pretrained(
#     "gemma-2-2b",
#     device=DEVICE,
# )

# pprint(gemma_2_2b.cfg)

## Get Gemma 2 SAEs

In [30]:
# width = 16

# saes = []
# for i in trange(gemma_2_2b.cfg.n_layers):
#     print(f"Downloading canonical SAE for layer {i} and width {width}k")
#     saes.append(
#         SAE.from_pretrained(
#             release="gemma-scope-2b-pt-res-canonical",
#             sae_id=f"layer_{i}/width_{width}k/canonical",
#             device=DEVICE,
#         )[0]
#     )

# pprint(saes[0].cfg)

In [31]:
# for sae in tqdm(saes):
#     fpath = pathlib.Path(
#         f"./gemma-scope-2b-pt-res-canonical/layer_{sae.cfg.hook_layer}__width_{str(sae.cfg.d_sae)[:-3]}k"
#     )
#     fpath.mkdir(parents=True, exist_ok=True)
#     sae.save_model(fpath)

## Find all features whose explanations contain keywords

In [32]:
# EXPLANATIONS_GEMMA_FPATH = "gemma-scope-2b-pt-res-canonical-w16k_explanations.json"

# try:
#     with open(EXPLANATIONS_GEMMA_FPATH, "r") as f:
#         explanations = json.load(f)
# except FileNotFoundError:
#     url = "https://www.neuronpedia.org/api/explanation/export"

#     explanations = []

#     for i in trange(len(saes)):
#         sae = saes[i]
#         model, sae_id = sae.cfg.neuronpedia_id.split("/")

#         querystring = {"modelId": model, "saeId": sae_id}

#         headers = {"X-Api-Key": os.getenv("NEURONPEDIA_TOKEN")}

#         response = requests.get(url, headers=headers, params=querystring)

#         explanations += response.json()

#     with open(EXPLANATIONS_GEMMA_FPATH, "w") as f:
#         json.dump(explanations, f, indent=2)

In [33]:
# KEYWORDS = ["POTTER"]

# explanations_filtered = [[] for i in range(len(saes))]
# explanation_count = 0

# for explanation in explanations:
#     if any(keyword in explanation["description"].upper() for keyword in KEYWORDS):
#         layer = int(explanation["layer"].split("-")[0])
#         explanations_filtered[layer].append(explanation)
#         explanation_count += 1

# for i in range(len(saes)):
#     print(f"Number of relevant features in layer {i}: {len(explanations_filtered[i])}")

# print(f"Total relevant features: {explanation_count}")


In [34]:
# sv_prompt = "Albus Dumbledore"
# sv_logits, cache = gemma_2_2b.run_with_cache(sv_prompt, prepend_bos=True)
# tokens = gemma_2_2b.to_tokens(sv_prompt)
# str_tokens = gemma_2_2b.to_str_tokens(tokens)
# print(f"Tokens: {tokens}")
# print(f"Token strings: {str_tokens}")

# k = 6
# act_threshold_relative = 0.005

# saes_out = []
# sv_feature_acts_vals_sorted_all_layers = []
# sv_feature_acts_idx_sorted_all_layers = []

# for i, sae in enumerate(saes):
#     explanations_filtered_idx = t.tensor(
#         [int(explanation["index"]) for explanation in explanations_filtered[i]],
#         device=DEVICE,
#     )
#     if explanations_filtered_idx.numel() == 0:
#         continue

#     sv_feature_acts = sae.encode(cache[sae.cfg.hook_name])
#     saes_out.append(sae.decode(sv_feature_acts))

#     act_max = sv_feature_acts.max()
#     act_threshold = act_threshold_relative * act_max

#     sv_feature_acts_filtered = sv_feature_acts[:, :, explanations_filtered_idx]

#     if (
#         sv_feature_acts_filtered.sum() * sv_feature_acts_filtered.numel()
#         < act_threshold
#     ):
#         continue

#     sv_feature_acts_vals_sorted, sv_feature_acts_idx_sorted = (
#         sv_feature_acts_filtered.sort(descending=True, dim=-1)
#     )

#     print(f"\nSorted activations for layer {i}")
#     print(f"\tMax activation for layer {i}: {act_max}")
#     for token_idx, str_token in enumerate(str_tokens):
#         vals_for_token = sv_feature_acts_vals_sorted[0, token_idx]
#         idx_for_token = explanations_filtered_idx[sv_feature_acts_idx_sorted][
#             0, token_idx
#         ]
#         mask = vals_for_token >= act_threshold
#         if not mask.any():
#             continue
#         print(f"\tToken {token_idx}: '{str_token}'")
#         print(f"\t\tRelevant feature activations: {vals_for_token[mask].tolist()}")
#         print(f"\t\tRelevant feature indices: {idx_for_token[mask].tolist()}")

# # from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list

# # get_neuronpedia_quick_list(
# #     sae=sae, features=sv_feature_acts_idx[:, :, :].flatten().tolist()
# # )

In [35]:
# STEERING_LAYER = 8
# steering_vector = saes[STEERING_LAYER].W_dec[3012]

# example_prompt = "My favourite protagonist in any fantasy novel series is named"
# coeff = 300
# sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)

# model = gemma_2_2b

# sae_out = saes_out[STEERING_LAYER]

In [36]:
# def steering_hook(resid_pre, hook):
#     if resid_pre.shape[1] == 1:
#         return

#     position = sae_out.shape[1]
#     if steering_on:
#         # using our steering vector and applying the coefficient
#         resid_pre[:, : position - 1, :] += coeff * steering_vector


# def hooked_generate(prompt_batch, fwd_hooks=[], seed=None, **kwargs):
#     if seed is not None:
#         t.manual_seed(seed)

#     with model.hooks(fwd_hooks=fwd_hooks):
#         tokenized = model.to_tokens(prompt_batch)
#         result = model.generate(
#             stop_at_eos=False,  # avoids a bug on MPS
#             input=tokenized,
#             max_new_tokens=50,
#             do_sample=True,
#             **kwargs,
#         )
#     return result


In [37]:
# def run_generate(example_prompt):
#     model.reset_hooks()
#     editing_hooks = [(f"blocks.{STEERING_LAYER}.hook_resid_post", steering_hook)]
#     res = hooked_generate(
#         [example_prompt] * 3, editing_hooks, seed=None, **sampling_kwargs
#     )

#     # Print results, removing the ugly beginning of sequence token
#     res_str = model.to_string(res[:, 1:])
#     print(("\n" + "-" * 80 + "\n").join(res_str))

In [38]:
# steering_on = True
# run_generate(example_prompt)

In [39]:
# steering_on = False
# run_generate(example_prompt)