## Set Up (Just Run / Not Important)

In [1]:
from IPython import get_ipython  # type: ignore
ipython = get_ipython()
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import os
import torch
from tqdm import tqdm
import plotly.express as px
import pandas as pd

from sae_lens import SAE
from sae_lens import HookedSAETransformer
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:

# model = HookedSAETransformer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
model = HookedSAETransformer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
print(f"Device: {device}")
model.to(device)
sae_choice = "saelens"
if sae_choice == "eleuther":
    sae = SAE.from_eleuther(
        release = "huypn16/sae-llama-3.2-1B-32x", # see other options in sae_lens/pretrained_saes.yaml
        sae_id = "layers.8", # won't always be a hook point
        device = device
    )
elif sae_choice == "saelens":
    sae_id = "blocks.8.hook_mlp_out"
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release = "llama-3.2-1B-mlp-math", # see other options in sae_lens/pretrained_saes.yaml
        sae_id = sae_id, # won't always be a hook point
        device = device
    )
elif sae_choice == "qwen":
    release = "huypn16/sae-qwen-2.5-1.5B-OMS-16x"
    sae, cfg_dict, sparsity = SAE.from_eleuther(
        release=release,  # see other options in sae_lens/pretrained_saes.yaml
        sae_id="layers.14",  # won't always be a hook point
        device=device,
    ) # type: ignore
else:
    raise ValueError(f"Invalid sae_choice: {sae_choice}")



Loading model config for Qwen/Qwen2.5-1.5B-Instruct
Loaded model config for {'d_model': 1536, 'd_head': 128, 'n_heads': 12, 'n_key_value_heads': 2, 'd_mlp': 8960, 'n_layers': 28, 'n_ctx': 2048, 'eps': 1e-06, 'd_vocab': 151936, 'act_fn': 'silu', 'use_attn_scale': True, 'initializer_range': 0.02, 'normalization_type': 'RMS', 'positional_embedding_type': 'rotary', 'rotary_base': 1000000.0, 'rotary_adjacent_pairs': False, 'rotary_dim': 128, 'tokenizer_prepends_bos': True, 'final_rms': True, 'gated_mlp': True, 'original_architecture': 'Qwen2ForCausalLM', 'tokenizer_name': 'Qwen/Qwen2.5-1.5B-Instruct'}




Loaded pretrained model Qwen/Qwen2.5-1.5B-Instruct into HookedTransformer
Device: cuda
Moving model to device:  cuda


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [3]:
%%capture
from transformer_lens.utils import tokenize_and_concatenate
from datasets import load_dataset
dataset = load_dataset(
    path="lighteval/MATH",
    streaming=False,
)
dataset = dataset["test"]
dataset = dataset.rename_column("solution", "text")
print(dataset.column_names)
token_dataset = tokenize_and_concatenate(
    dataset=dataset,
    tokenizer=model.tokenizer,
    streaming=True,
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
)
# feature_text = "Add 2 numbers together" -> feature_vector
# overlaps = database(samples).query(feature_text)
# common_semantic = databse(samples).query(feature_vector)
# -> algorithm -> combination vector -> add, rescaling

In [4]:
sae.eval()  # prevents error if we're expecting a dead neuron mask for who grads

with torch.no_grad():
    # activation store can give us tokens.
    batch_tokens = token_dataset[:8]["tokens"]
    _, cache = model.run_with_cache_with_saes(batch_tokens, saes=[sae])
    feature_acts = cache["blocks.8.hook_resid_post.hook_sae_acts_pre"]
    top_k_acts = cache["blocks.8.hook_resid_post.hook_sae_acts_post"]
    # save some room
    del cache
    del batch_tokens
    del _
    # ignore the bos token, get the number of features that activated in each token, averaged accross batch and position
    l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
    print("average l0", l0.mean().item())
    px.histogram(l0.flatten().cpu().numpy()).show() # the distribution of the number of active features per token
    del l0

RuntimeError: The size of tensor a (1536) must match the size of tensor b (2048) at non-singleton dimension 2

In [4]:
# histogram of feature that activated
non_zero_activations = top_k_acts[:, 1:, :] > 0
indices = torch.where(non_zero_activations)[-1]  # Extract the column (M) indices
result = -torch.ones_like(non_zero_activations, dtype=torch.int64)  # Initialize with -1
result[non_zero_activations] = indices
features = result[result != -1].flatten().cpu().numpy()
# count the number of activations per feature
# import numpy as np
# unique, counts = np.unique(features, return_counts=True)
# print(unique, counts)
px.histogram(pd.DataFrame(features, columns=["features"])).show()

NameError: name 'top_k_acts' is not defined

In [6]:
prompt = "Solver the following equation: 44 - x = 32, what is x? 44 - x"
_, cache = model.run_with_cache_with_saes(prompt, saes=[sae])
topk = 10

hook_id = "blocks.14.hook_resid_post.hook_sae_acts_pre"
fig = px.line(
    cache[hook_id][0, -1, :].cpu().tolist(),
    title="Feature activations at the final token position",
    labels={"index": "Feature", "value": "Activation"},
)
# let's print the top 5 features and how much they fired
vals, inds = torch.topk(
    cache[hook_id][0, -1, :], topk
)
fig.show()

hook_id = "blocks.14.hook_resid_post.hook_sae_acts_post"
fig = px.line(
    cache[hook_id][0, -1, :].cpu().tolist(),
    title="Feature activations at the final token position",
    labels={"index": "Feature", "value": "Activation"},
)
# let's print the top 5 features and how much they fired
vals, inds = torch.topk(
    cache[hook_id][0, -1, :], topk
)
for val, ind in zip(vals, inds):
    print(f"Feature {ind} fired magniture {val:.2f}")
fig.show()

# for val, ind in zip(vals, inds):
#     print(f"Feature {ind} fired {val:.2f}")
#     html = get_dashboard_html(
#         sae_release="gpt2-small", sae_id="7-res-jb", feature_idx=ind
#     )
#     display(IFrame(html, width=1200, height=300))

Feature 21719 fired magniture 29.53
Feature 630 fired magniture 14.65
Feature 15876 fired magniture 11.37
Feature 7384 fired magniture 9.15
Feature 4411 fired magniture 8.45
Feature 22720 fired magniture 7.41
Feature 18354 fired magniture 6.50
Feature 11984 fired magniture 5.94
Feature 16084 fired magniture 5.45
Feature 22735 fired magniture 5.30


In [12]:
def filter(activations, threshold=0.01) -> tuple[list, int]:
    print("OK")
    total_activations = [1 for x in activations if x >= threshold]
    total_activations = sum(total_activations)
    attributions = [0 if x < threshold else x for x in activations]
    return attributions, total_activations

Let's plot multiple activation vectors

In [25]:
prompt = [
    "Solver the following equation: 44 - x = 32, what is x? 44 - x",
    "Solver the following equation: 44 - x = 32, what is x? 44 - 32"
]

# directions = ["Solving the following mathematical problem. Problem: Calculate the following expression: (4 + 3 * 2 - 1 ) * 4 - 2. Step 1: First, we add 4 and 3 to get 7.",
#              "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 23 * 4 - 1 ) * 3 - 2. Step 1: First, we add 12 and 23 to get 35.",
#              "Solving the following mathematical problem. Problem: Calculate the following expression: (34 + 2231 * 6 - 1 ) * 34 - 2. Step 1: First, we add 34 and 2231 to get 2265.",
#              "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 123 * 2 - 1 ) * 412 - 2. Step 1: First, we add 12 and 123 to get 135.",
#              "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 12445 * 2 - 1 ) * 412 - 2. Step 1: First, we add 12 and 12445 to get 12457.",
#              "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 12446 * 2 - 1 ) * 412 - 2. Step 1: First, we add 12 and 123 to get 12458.",
#              "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 12447 * 2 - 1 ) * 412 - 2. Step 1: First, we add 12 and 123 to get 12459.",
#              "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 1000 * 2 - 1 ) * 412 - 2. Step 1: First, we add 12 and 1000 to get 1012.",]

directions = ["Solving the following mathematical problem. Problem: Calculate the following expression: (4 + 3 * 2 - 1 ) * 4 - 2. Step 1: First, we add 4 and 3",
             "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 23 * 4 - 1 ) * 3 - 2. Step 1: First, we add 12 and 23",
             "Solving the following mathematical problem. Problem: Calculate the following expression: (34 + 2231 * 6 - 1 ) * 34 - 2. Step 1: First, we add 34 and 2231",
             "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 123 * 2 - 1 ) * 412 - 2. Step 1: First, we add 12 and 123",
             "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 12445 * 2 - 1 ) * 412 - 2. Step 1: First, we add 12 and 12445",
             "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 12446 * 2 - 1 ) * 412 - 2. Step 1: First, we add 12 and 123",
             "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 12447 * 2 - 1 ) * 412 - 2. Step 1: First, we add 12 and 123",
             "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 1000 * 2 - 1 ) * 412 - 2. Step 1: First, we add 12 and 1000",]

directions = ["Solving the following mathematical problem. Problem: Calculate the following expression: (4 + 3 * 2 - 1 ) * 4 - 2. Step 1: First, we minus 2 and 1",
             "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 23 * 4 - 1 ) * 3 - 2. Step 1: First, we minus 4 and 1", 
             "Solving the following mathematical problem. Problem: Calculate the following expression: (34 + 2231 * 6 - 1 ) * 34 - 2. Step 1: First, we minus 6 and 1",
             "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 123 * 2 - 1 ) * 412 - 2. Step 1: First, we minus 2 and 1",
             "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 12445 * 2 - 1 ) * 412 - 2. Step 1: First, we minus 2 and 1",
             "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 12446 * 2 - 1 ) * 412 - 2. Step 1: First, we minus 2 and 1",
             "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 12447 * 2 - 1 ) * 412 - 2. Step 1: First, we minus 2 and 1",
             "Solving the following mathematical problem. Problem: Calculate the following expression: (12 + 1000 * 2 - 1 ) * 412 - 2. Step 1: First, we minus 2 and 1",]

# hook_id = 'blocks.8.hook_mlp_out.hook_sae_acts_post'
hook_id = "blocks.14.hook_resid_post.hook_sae_acts_post"
_, cache = model.run_with_cache_with_saes(directions, saes=[sae])

# attention_positions = [[-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1]]
# attention_positions = [[-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1]]
attention_positions = [-1, -1, -1, -1, -1, -1, -1, -1]

feature_activation_df = pd.DataFrame(
    cache[hook_id][0, attention_positions[0], :].cpu().tolist(),
    index=[f"feature_{i}" for i in range(sae.cfg.d_sae)],
)
feature_activation_df.columns = ["0"]
n_features = sae.cfg.d_sae
mask = torch.ones(n_features)
# counter of the number of activations for each feature accross all directions
activations_counter = torch.zeros(n_features)
for did, direction in tqdm(enumerate(directions)):
    if did == 0: continue
    activation = cache[hook_id][did, attention_positions[did], :].cpu()
    feature_activation_df[f"{did}"] = (activation.tolist())
    activation_mask = activation > 3.0
    print(f"Actvation mask  {did}: ", activation_mask.nonzero(as_tuple=True)[0])
    activations_counter += activation_mask.int()
    mask = mask * activation_mask

# ex 1, contextual 1
# ->abstraction common: latent variable
# ex 2

# take out the features that are not activated # dataframe
nonzero_counter = activations_counter[activations_counter.nonzero(as_tuple=True)[0]]
nonzero_feature = activations_counter.nonzero(as_tuple=True)[0]
# create dataframe with the number of activations for each feature
feature_df = pd.DataFrame(
    nonzero_counter.tolist(),
    index=[i.item() for i in nonzero_feature],
)
# filter the features that are fired  > 5:
for n_fired in range(len(directions)):
    feature_df = feature_df[feature_df[0] > n_fired]
    print(f"Features fired in more than {n_fired} times: ", feature_df[0].index.tolist())

# sort the features by the number of activations
feature_df = feature_df.sort_values(by=[0], ascending=False)
print("Feature activations: ", feature_df)

print("Mask: ", mask.nonzero(as_tuple=True)[0])

fig = px.line(
    feature_activation_df,
    title="Feature activations for the prompt",
    labels={"index": "Feature", "value": "Activation"},
)

# hide the x-ticks
# fig.update_xaxes(showticklabels=False)
# fig.show()

8it [00:00, 516.38it/s]

Actvation mask  1:  tensor([ 4087,  4954, 10327, 14344, 18312, 21719, 22720])
Actvation mask  2:  tensor([ 4087,  4444,  4954, 10327, 14344, 18312, 21719, 21850, 22720])
Actvation mask  3:  tensor([ 4087,  4444,  4954, 10327, 14344, 18312, 21719, 21850, 22720])
Actvation mask  4:  tensor([  101,   345,  4297,  4410,  4411,  4782,  4783,  4877,  6551,  6878,
         7384,  7410,  8303,  9321, 10738, 11302, 11594, 13107, 14068, 15023,
        15311, 16451, 17808, 17975, 18038, 18758, 18923, 20166, 21021, 21719,
        21792, 23602, 24219])
Actvation mask  5:  tensor([  101,   345,  4297,  4410,  4411,  4782,  4783,  4877,  6551,  6878,
         7384,  7410,  8303,  9321, 10738, 11302, 11594, 13107, 14068, 15023,
        15311, 16451, 17808, 17975, 18038, 18758, 18923, 20166, 21021, 21719,
        21792, 23602, 24219])
Actvation mask  6:  tensor([  101,   345,  4297,  4410,  4411,  4782,  4783,  4877,  6551,  6878,
         7384,  7410,  8303,  9321, 10738, 11302, 11594, 13107, 14068, 1




## Local: Finding Max Activating Examples

In [17]:
# instantiate an object to hold activations from a dataset
from sae_lens import ActivationsStore

# a convenient way to instantiate an activation store is to use the from_sae method
activation_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device=device,
)

IndexError: list index out of range

In [22]:
def list_flatten(nested_list):
    return [x for y in nested_list for x in y]

# A very handy function Neel wrote to get context around a feature activation
def make_token_df(tokens, len_prefix=5, len_suffix=3, model=model):
    str_tokens = [model.to_str_tokens(t) for t in tokens]
    unique_token = [
        [f"{s}/{i}" for i, s in enumerate(str_tok)] for str_tok in str_tokens
    ]

    context = []
    prompt = []
    pos = []
    label = []
    for b in range(tokens.shape[0]):
        for p in range(tokens.shape[1]):
            prefix = "".join(str_tokens[b][max(0, p - len_prefix) : p])
            if p == tokens.shape[1] - 1:
                suffix = ""
            else:
                suffix = "".join(
                    str_tokens[b][p + 1 : min(tokens.shape[1] - 1, p + 1 + len_suffix)]
                )
            current = str_tokens[b][p]
            context.append(f"{prefix}|{current}|{suffix}")
            prompt.append(b)
            pos.append(p)
            label.append(f"{b}/{p}")
    # print(len(batch), len(pos), len(context), len(label))
    return pd.DataFrame(
        dict(
            str_tokens=list_flatten(str_tokens),
            unique_token=list_flatten(unique_token),
            context=context,
            prompt=prompt,
            pos=pos,
            label=label,
        )
    )

In [None]:
# finding max activating examples is a bit harder. To do this we need to calculate feature activations for a large number of tokens
sampling_feature = 10
feature_list = torch.randint(0, sae.cfg.d_sae, (sampling_feature,))
examples_found = 0
all_fired_tokens = []
all_feature_acts = []
all_reconstructions = []
all_token_dfs = []

total_batches = 100
batch_size_prompts = activation_store.store_batch_size_prompts
batch_size_tokens = activation_store.context_size * batch_size_prompts
pbar = tqdm(range(total_batches))
for i in pbar:
    tokens = activation_store.get_batch_tokens()
    tokens_df = make_token_df(tokens)
    tokens_df["batch"] = i

    flat_tokens = tokens.flatten()

    _, cache = model.run_with_cache(
        tokens, stop_at_layer=sae.cfg.hook_layer + 1, names_filter=[sae.cfg.hook_name]
    )
    sae_in = cache[sae.cfg.hook_name]
    feature_acts = sae.encode(sae_in).squeeze()

    feature_acts = feature_acts.flatten(0, 1)
    fired_mask = (feature_acts[:, feature_list]).sum(dim=-1) > 0
    fired_tokens = model.to_str_tokens(flat_tokens[fired_mask])
    reconstruction = feature_acts[fired_mask][:, feature_list] @ sae.W_dec[feature_list]

    token_df = tokens_df.iloc[fired_mask.cpu().nonzero().flatten().numpy()]
    all_token_dfs.append(token_df)
    all_feature_acts.append(feature_acts[fired_mask][:, feature_list])
    all_fired_tokens.append(fired_tokens)
    all_reconstructions.append(reconstruction)

    examples_found += len(fired_tokens)
    # print(f"Examples found: {examples_found}")
    # update description
    pbar.set_description(f"Examples found: {examples_found}")

# flatten the list of lists
all_token_dfs = pd.concat(all_token_dfs)
all_fired_tokens = list_flatten(all_fired_tokens)
all_reconstructions = torch.cat(all_reconstructions)
all_feature_acts = torch.cat(all_feature_acts)

## Getting Feature Activation Histogram

Next, we can generate the feature activation histogram (just as we saw on the dashboards above) and display the list of max-activating examples we just generated. We'll just do this for the first feature in our random set (index 0).

In [None]:
feature_acts_df = pd.DataFrame(
    all_feature_acts.detach().cpu().tolist(),
    columns=[f"feature_{i}" for i in feature_list],
)
feature_acts_df.shape

In [None]:
feature_idx = 0
# get non-zero activations

all_positive_acts = all_feature_acts[all_feature_acts[:, feature_idx] > 0][
    :, feature_idx
].detach()
prop_positive_activations = (
    100 * len(all_positive_acts) / (total_batches * batch_size_tokens)
)

px.histogram(
    all_positive_acts.cpu().tolist(),
    nbins=50,
    title=f"Histogram of positive activations - {prop_positive_activations:.3f}% of activations were positive",
    labels={"value": "Activation"},
    width=800,
)

In [25]:
top_10_activations = feature_acts_df.sort_values(
    f"feature_{feature_list[0]}", ascending=False
).head(10000)
# save the sub dataframe
top_10_df = all_token_dfs.iloc[
    top_10_activations.index
]  # TODO: double check this is working correctly
# top_10_df.to_csv("top_10_activations.csv", index=False)

## Getting the Top 10 Logit Weights

As a final step, we'll generate the top 10 logit weights--that is, we'll see what tokens each of the features in our set is promoting most strongly.

Note it's important to fold layer norm (by default SAE Lens loads Transformers with folder layer norm but sometimes we turn preprocessing off to save GPU ram and this would affect the logit weight histograms a little bit).

In [None]:
print(f"Shape of the decoder weights {sae.W_dec.shape})")
print(f"Shape of the model unembed {model.W_U.shape}")
# convert datatype from float16 to float32
projection_matrix = sae.W_dec.float() @ model.W_U.float()
print(f"Shape of the projection matrix {projection_matrix.shape}")

# then we take the top_k tokens per feature and decode them
top_k = 10
# let's do this for 100 random features
_, top_k_tokens = torch.topk(projection_matrix[feature_list], top_k, dim=1)


feature_df = pd.DataFrame(
    top_k_tokens.cpu().numpy(), index=[f"feature_{i}" for i in feature_list]
).T
feature_df.index = [f"token_{i}" for i in range(top_k)]
feature_df.applymap(lambda x: model.tokenizer.decode(x))

## Feature Steering

In [None]:
from tqdm import tqdm
from functools import partial


def find_max_activation(model, sae, activation_store, feature_idx, num_batches=100):
    """
    Find the maximum activation for a given feature index. This is useful for
    calibrating the right amount of the feature to add.
    """
    max_activation = 0.0

    pbar = tqdm(range(num_batches))
    for _ in pbar:
        tokens = activation_store.get_batch_tokens()

        _, cache = model.run_with_cache(
            tokens,
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[sae.cfg.hook_name],
        )
        sae_in = cache[sae.cfg.hook_name]
        feature_acts = sae.encode(sae_in).squeeze()

        feature_acts = feature_acts.flatten(0, 1)
        batch_max_activation = feature_acts[:, feature_idx].max().item()
        max_activation = max(max_activation, batch_max_activation)

        pbar.set_description(f"Max activation: {max_activation:.4f}")

    return max_activation


def steering(
    activations, hook, steering_strength=1.0, steering_vector=None, max_act=1.0
):
    # Note if the feature fires anyway, we'd be adding to that here.
    return activations + max_act * steering_strength * steering_vector


def generate_with_steering(
    model,
    sae,
    prompt,
    steering_feature,
    max_act,
    steering_strength=1.0,
    max_new_tokens=95,
):
    input_ids = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)

    steering_vector = sae.W_dec[steering_feature].to(model.cfg.device)

    steering_hook = partial(
        steering,
        steering_vector=steering_vector,
        steering_strength=steering_strength,
        max_act=max_act,
    )

    # standard transformerlens syntax for a hook context for generation
    with model.hooks(fwd_hooks=[(sae.cfg.hook_name, steering_hook)]):
        output = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            stop_at_eos=False if device == "mps" else True,
            prepend_bos=sae.cfg.prepend_bos,
        )

    return model.tokenizer.decode(output[0])


# Choose a feature to steer
# features = [7928,  8732, 14285, 17533]
features = [ 1727, 14975, 53332]
steering_feature = steering_feature = 1727  # Choose a feature to steer towards

# Find the maximum activation for this feature
# max_act = find_max_activation(model, sae, activation_store, steering_feature)
max_act = 1.0
print(f"Maximum activation for feature {steering_feature}: {max_act:.4f}")

# Generate text without steering for comparison
prompt = "Solving the following mathematical problem. Problem: Calculate the following expression: (10222 + 23123123 * 4 - 1 ) * 5 - 6. Step 1:"
normal_text = model.generate(
    prompt,
    max_new_tokens=95,
    stop_at_eos=False if device == "mps" else True,
    prepend_bos=sae.cfg.prepend_bos,
)

print("\nNormal text (without steering):")
print(normal_text)

# Generate text with steering
steered_text = generate_with_steering(
    model, sae, prompt, steering_feature, max_act, steering_strength=10.0
)
print("Steered text:")
print(steered_text)

In [None]:
# Experiment with different steering strengths
prompt = "Solving the following mathematical problem. Problem: Calculate the following expression: (10222 + 23123123 * 4 - 1 ) * 5 - 6. Step 1:"

# print("\nExperimenting with different steering strengths:")
# for strength in [-4.0, -2.0, 0.5, 2.0, 4.0]:
#     steered_text = generate_with_steering(
#         model, sae, prompt, steering_feature, max_act, steering_strength=strength
#     )
#     print(f"\nSteering strength {strength}:")
#     print(steered_text)
    
print("\nExperimenting with different steering features:")
# for fid in [4636,  8439,  9652, 10868, 11551, 18909, 31366, 32379]:
for fid in [ 1727, 14975, 53332]:
    max_act = find_max_activation(model, sae, activation_store, fid)
    print(f"Maximum activation for feature {fid}: {max_act:.4f}")
    for strength in [1.0, 2.0, 4.0, 8.0]:
        print(f"\nSteering strength {strength}:")
        steered_text = generate_with_steering(
            model, sae, prompt, steering_feature, max_act, steering_strength=strength
        )
        print(steered_text)
# x = y / z
# y = x * z
# y = x / z
# worked well for 8439

# anti-semantic but logical pattern still works