In [1]:
import os
import torch
from tqdm import tqdm
import plotly.express as px

# Inference only: disable autograd globally so forward passes don't build a
# backward graph. The trailing ";" stops the notebook echoing the returned
# context-manager object as cell output.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7f8260ab7fd0>

In [2]:
# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

# Backend preference order: Apple Metal (MPS), then CUDA, otherwise plain CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

Device: cuda


In [3]:
# Optional: uncomment to log in to the Hugging Face Hub from the notebook,
# e.g. if the model/SAE downloads in the next cell require an authenticated account.
# from huggingface_hub import notebook_login
# notebook_login()

In [5]:
from datasets import load_dataset
from transformer_lens import HookedTransformer
from sae_lens import SAE

# Base model, wrapped by TransformerLens so that named hook points are available.
model = HookedTransformer.from_pretrained("gemma-2-2b", device=device)
print("Model loaded successfully")

# SAE.from_pretrained hands back three things:
#   sae      - the sparse autoencoder module itself;
#   cfg_dict - whatever config lived in the HF repo (NOT the SAE's own config dict,
#              though that can be derived from it), handy e.g. for building an
#              activation store;
#   sparsity - the per-feature sparsities stored on HF, returned for convenience.
# Other releases are listed in sae_lens/pretrained_saes.yaml. The sae_id is an
# identifier inside the release and won't always be a hook point name.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    device=device,
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_12/width_16k/canonical",
)
print("SAE loaded successfully")



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer
Model loaded successfully
SAE loaded successfully


In [8]:
# def zero_abl_hook(activation, hook):
#     return torch.zeros_like(activation)

def direct(activation, hook):
    """Pass-through hook: return ``activation`` exactly as received.

    Acts as the unsteered baseline. Generation runs under the same
    ``model.hooks`` setup as the steered run, but nothing is modified.
    ``hook`` is accepted only to match the hook-function call signature.
    """
    unchanged = activation
    return unchanged

def clamping(activation, hook, feature_idx=6, value=150.0, sae_model=None):
    """Steering hook: clamp one SAE latent and return the SAE reconstruction.

    The hooked activation is encoded with the SAE. Latent ``feature_idx`` is
    overwritten with ``value`` at every batch element and position, and the
    decoded result replaces the original activation.

    Args:
        activation: Activation tensor captured at the SAE's hook point. This is
            assumed to be (batch, seq, d_model); TODO confirm for other hook points.
        hook: The hook point object passed by the caller (unused).
        feature_idx: Index of the SAE latent to clamp. Defaults to 6, the value
            previously hard-coded.
        value: Value written into that latent. Defaults to 150.0.
        sae_model: SAE providing ``encode``/``decode``. Defaults to the
            notebook-global ``sae``.

    Returns:
        ``sae_model.decode`` of the modified latents. The original activation is
        not added back, so the SAE's reconstruction error is discarded as well
        as the feature being clamped.
    """
    if sae_model is None:
        sae_model = sae
    latents = sae_model.encode(activation)
    # Clamp across *all* leading dims. The previous `encoded[0, :, 6]` only
    # steered batch element 0, leaving every other prompt in a batch untouched.
    latents[..., feature_idx] = value
    return sae_model.decode(latents)

In [9]:
example_prompt = "Between dog and an fish, a dolphin is closer to a"
MAX_NEW_TOKENS = 100
# Re-seeded before every run so the baseline and the steered generation start
# from the same RNG state (do_sample=True is otherwise irreproducible).
SEED = 0


def generate_with_hook(prompt_tokens, hook_fn, max_new_tokens=MAX_NEW_TOKENS, seed=SEED):
    """Sample a continuation while ``hook_fn`` is attached at the SAE's hook point.

    Uses the notebook-global ``model`` and ``sae``. The hook is registered only
    for the duration of the ``model.hooks`` context.

    Returns:
        The decoded text (prompt included) of the first sequence in the batch.
    """
    torch.manual_seed(seed)
    with model.hooks(fwd_hooks=[(sae.cfg.hook_name, hook_fn)]):
        output_ids = model.generate(prompt_tokens, max_new_tokens=max_new_tokens, do_sample=True)
    return model.tokenizer.decode(output_ids[0])


# Tokenize once: token ids don't depend on which hook is active.
input_tokens = model.to_tokens(example_prompt, prepend_bos=True)

# Baseline (identity hook) first, then the feature-clamped run.
for hook_fn in (direct, clamping):
    print(generate_with_hook(input_tokens, hook_fn))



  0%|          | 0/100 [00:00<?, ?it/s]

<bos>Between dog and an fish, a dolphin is closer to a shark. That is a more logical explanation. And not one that makes any sense whatsoever.
Because... they're sharks.
Because they are marine life and the shark lives in water.
Probably because it lives in the ocean, or something like that.
Because they are similar... it may not make sense, but just think about it. Kangaroos look nothing and similar to a lamp but a lamp and a kangaroo are animals. Watch how there two feet and legs move in the same


  0%|          | 0/100 [00:00<?, ?it/s]

<bos>Between dog and an fish, a dolphin is closer to a Fish.
Between Dog and, I
Between Dog and I
I Watch III and I Watch I
I Watch I
I Watch I
I Watch I Watch,andelson I Watch I

The glory of which the gun.
I watch the Moon.
Watch I
Thet and I watch and I
I
I
I
Iey
i
I
i
I
i
y

I
I
NAMES

I

I, I

<h6>
