In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from typing import *
import einops
import transformer_lens
from functools import partial
import sae_vis
from IPython.display import HTML, display
import jaxtyping as jt

import datasets
import torch
# import jax  # was only needed for tree map, which I deleted
import gc
from tqdm.auto import tqdm
from IPython.display import clear_output
from tqdm.auto import tqdm

torch.set_grad_enabled(False)

clear_output()

In [3]:
# Load GPT-2 XL through the "no_processing" constructor. Per the original note
# below, the SAE used later in this notebook was trained without
# TransformerLens's usual weight processing.
gpt2xl = transformer_lens.HookedSAETransformer.from_pretrained_no_processing("gpt2-xl")
# SAE was trained without TL's nice things.

# Wipe whatever the load printed so the cell output stays empty.
clear_output()

In [4]:
sae = transformer_lens.HookedSAE.from_pretrained("gpt2-xl-saex-resid-pre-l20")



In [5]:
data = datasets.load_dataset("Elriggs/openwebtext-100k", streaming=False)
data = data["train"]

In [6]:
iter_data = iter(data)

In [7]:
def get_neglogprobs(logits, tokens):
    """Return the negative log-probability of each next token.

    Args:
        logits: model outputs indexed as (batch, position, vocab).
        tokens: integer token ids indexed as (batch, position); the token at
            position ``i + 1`` is the target for the logits at position ``i``.

    Returns:
        A (batch, position - 1) tensor: one value per predicted token. The
        last position is dropped because it has no target.
    """
    logprobs = logits.log_softmax(dim=-1)[:, :-1]
    # gather needs the index on the same device as the values it selects from.
    targets = tokens[:, 1:].to(logprobs.device).unsqueeze(-1)
    return -logprobs.gather(dim=-1, index=targets).squeeze(-1)


# Residual-stream hook point used throughout the notebook (later cells read
# this global, so it must stay at module level).
site = "blocks.20.hook_resid_pre"

EVAL_N_DOCS = 200   # number of documents pulled from `iter_data`
EVAL_MAX_LEN = 128  # each document is truncated to this many tokens

total_loss = 0.0
total_sae_loss = 0.0
total_toks = 0

# NOTE: `iter_data` is consumed statefully, so re-running this cell evaluates
# the *next* EVAL_N_DOCS documents rather than the same ones.
for _ in tqdm(range(EVAL_N_DOCS)):
    text = next(iter_data)["text"]
    tokens = gpt2xl.to_tokens(text)[:, :EVAL_MAX_LEN]
    # One prediction per token except the last, matching get_neglogprobs.
    total_toks += tokens[:, :-1].numel()

    # Clean model loss.
    logits, all_acts = gpt2xl.run_with_cache(
        tokens,
        names_filter = site,
    )
    acts = all_acts[site]
    neglogprobs = get_neglogprobs(logits, tokens)
    total_loss += neglogprobs.sum().item()

    # Loss with the SAE spliced into the forward pass.
    sae_logits = gpt2xl.run_with_saes(
        tokens,
        saes = [sae],
    )
    sae_loss = get_neglogprobs(sae_logits, tokens)
    total_sae_loss += sae_loss.sum().item()



  0%|          | 0/200 [00:00<?, ?it/s]

In [8]:
total_loss / total_toks, total_sae_loss / total_toks

(2.7160975671002245, 2.9482859949427325)

In [116]:
# 0.23 loss diff is much bigger than claimed... let's press on anyways.

In [9]:
# Cache key for the SAE activations at blocks.20.hook_resid_pre. Judging by
# the hook name these are presumably the pre-nonlinearity values (TODO confirm).
sae_acts_pre_hook_name = "blocks.20.hook_resid_pre.hook_sae_acts_pre"
# NOTE(review): `tokens` is not defined in this cell. It is whatever an earlier
# cell bound last (the final document of the loss loop when run top to bottom),
# so the result depends on execution order.
sae_logits, sae_cache = gpt2xl.run_with_cache_with_saes(
    tokens,
    saes = [sae],
    names_filter=sae_acts_pre_hook_name
)
# Keep only the one cached entry that was requested via names_filter.
sae_acts_pre = sae_cache[sae_acts_pre_hook_name]
print(sae_acts_pre.shape)

torch.Size([1, 128, 51200])


In [10]:
# Print avg L0

((sae_acts_pre > 0).sum(dim=-1).float()).mean()

# 24

tensor(24.7578, device='cuda:0')

In [11]:
# from SAE vis demo.
SEQ_LEN = 128

# Tokenize the data (using a utils function) and shuffle it
tokenized_data = transformer_lens.utils.tokenize_and_concatenate(data, gpt2xl.tokenizer, max_length=SEQ_LEN) # type: ignore
tokenized_data = tokenized_data.shuffle(42)

# Get the tokens as a tensor
all_tokens = tokenized_data["tokens"]
assert isinstance(all_tokens, torch.Tensor)

print(all_tokens.shape)

Map (num_proc=10):   0%|          | 0/100000 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (59374 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (56534 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (55671 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (58948 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (59253 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence

torch.Size([888650, 128])


In [12]:
sae_vis_sae_cfg = sae_vis.model_fns.AutoEncoderConfig(
    d_in=sae.cfg.d_in,
    d_hidden=sae.cfg.d_sae
)

In [15]:
# ugh gross TODO(conmy): raise issues to standardise the SAEs used by various libraries

sae_vis_sae = sae_vis.model_fns.AutoEncoder(sae_vis_sae_cfg)
sae_vis_sae.load_state_dict(sae.state_dict())


<All keys matched successfully>

In [21]:
all_tokens.device

device(type='cpu')

In [16]:
def get_sae_vis_data(features, batch_size=2048, verbose=True):
    """Gather sae_vis dashboard data for the given SAE features.

    Reads the notebook globals ``site`` (hook point), ``sae_vis_sae``
    (encoder), ``gpt2xl`` (model) and ``all_tokens`` (token dataset).

    Args:
        features: feature indices to analyse, forwarded to
            ``sae_vis.SaeVisConfig(features=...)``.
        batch_size: forwarded to ``SaeVisConfig(batch_size=...)``. Defaults to
            the previously hard-coded 2048; exposed so callers can lower it.
        verbose: forwarded to ``SaeVisConfig(verbose=...)``.

    Returns:
        Whatever ``sae_vis.SaeVisData.create`` returns for this configuration.
    """
    # Specify the hook point, the features to analyse, and the batch size used
    # when gathering activations.
    sae_vis_config = sae_vis.SaeVisConfig(
        hook_point = site,
        features = features,
        batch_size = batch_size,
        verbose = verbose,
    )

    # Gather the feature data
    return sae_vis.SaeVisData.create(
        encoder = sae_vis_sae,
        model = gpt2xl,
        tokens = all_tokens, # type: ignore
        cfg = sae_vis_config,
    )

feature_idx = 126
sae_vis_data = get_sae_vis_data([feature_idx])

Forward passes to cache data for vis:   0%|          | 0/32 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 50.00 MiB. GPU 

In [124]:
import time

def save_sae_vis_html(sae_vis_data):
    """Write one feature-centric HTML dashboard per feature and list the files.

    For every key of ``sae_vis_data.feature_data_dict`` this calls
    ``sae_vis_data.save_feature_centric_vis`` with a file name of the form
    ``feature_vis_demo_<millisecond timestamp>.html``. Nothing is opened in a
    browser; the files are only written.

    Returns:
        The generated file names, in the iteration order of the feature dict.
    """
    def _write_one(idx):
        # The name is derived from the wall clock just before the save call.
        stamp_ms = int(1000*time.time())
        path = f"feature_vis_demo_{stamp_ms}.html"
        sae_vis_data.save_feature_centric_vis(path, feature_idx=idx)
        return path

    return [_write_one(idx) for idx in sae_vis_data.feature_data_dict.keys()]

filename = save_sae_vis_html(sae_vis_data)[0]  # Download to visualize.

Saving feature-centric vis: 100%|██████████| 1/1 [00:00<00:00, 14.14it/s]


In [23]:
gpt2xl.W_E.device

device(type='cuda', index=0)

In [22]:
# Using the AF post
# Build a steering vector as the scaled difference between the activations of
# two contrasting prompts at the hook point `site`.

pos_prompt = "Anger"  # @param {"type": "string"}
neg_prompt = "Calm"  # @param {"type": "string"}

pos_tokens = gpt2xl.to_tokens(pos_prompt, prepend_bos=True)
neg_tokens = gpt2xl.to_tokens(neg_prompt, prepend_bos=True)
# The two activations are subtracted position by position below, so both
# prompts must tokenize to the same shape.
assert pos_tokens.shape == neg_tokens.shape, (pos_tokens.shape, "!=" , neg_tokens.shape)

# NOTE: `site` is a global defined by the loss-evaluation cell earlier on.
gpt2xl.reset_hooks()
_, cache = gpt2xl.run_with_cache(
    pos_tokens,
    names_filter = site,
)
# Drop the batch axis (first batch element); assumes the remaining axes are
# (position, d_model) -- TODO confirm.
pos_vec = cache[site][0, :]

gpt2xl.reset_hooks()
_, neg_cache = gpt2xl.run_with_cache(
    neg_tokens,
    names_filter = site,
)

neg_vec = neg_cache[site][0, :]
# 20 is the steering coefficient applied to the activation difference.
anger_steering_vec = 20*(pos_vec - neg_vec)

def activation_generation_hook(
    clean_activation: jt.Float[torch.Tensor, "Batch Seq *Dim"],
    hook: Any,
    indices: slice,
    v: jt.Float[torch.Tensor, "SubSeq *Dim"],
    debug: bool = False,
) -> jt.Float[torch.Tensor, "Batch Seq Dim"]:
    """TransformerLens hook that steers the prompt pass but not the rollout.

    Args:
        clean_activation: activation at the hooked site; modified in place.
        hook: hook point object supplied by the caller; unused here.
        indices: slice along the sequence axis that receives the vector.
        v: vector(s) added to ``clean_activation[:, indices]``.
        debug: when true, print the norms of the affected positions before
            and after the addition.

    Returns:
        The same tensor object that was passed in. It is returned untouched
        when the sequence axis has length 1, which is treated as an
        autoregressive step.
    """
    seq_len = clean_activation.shape[1]
    if seq_len == 1:
        # Doing autoregression. No injection
        return clean_activation

    def _report(label):
        # Norm over the last axis of the positions being steered.
        print(label, clean_activation[:, indices].norm(dim=-1))

    if debug:
        _report("NORM PRE ADD")

    clean_activation[:, indices] += v

    if debug:
        _report("NORM POST ADD")

    return clean_activation

prompt = "I think you're"
torch.random.manual_seed(100)

tokens = gpt2xl.to_tokens(prompt)

def get_steered_completion(
    tokens,
    steering_vec,
    indices,
    max_new_tokens=30,
    top_p=0.3,
    temperature=1.0,
    freq_penalty=1.0,
):
    """Generate completions from ``gpt2xl`` with a steering vector added.

    A forward hook on the global ``site`` adds ``steering_vec`` to the
    activations at the sequence positions selected by ``indices`` for the
    duration of the ``generate`` call. The hook is installed through the
    ``gpt2xl.hooks`` context manager.

    Args:
        tokens: prompt forwarded unchanged to ``gpt2xl.generate``. Cells in
            this notebook pass both token tensors and a raw string.
        steering_vec: vector(s) handed to ``activation_generation_hook`` as ``v``.
        indices: slice of sequence positions to steer.
        max_new_tokens: number of tokens to sample (default 30, the value
            noted as taken from the Turner blog post).
        top_p: nucleus-sampling threshold forwarded to ``generate``.
        temperature: sampling temperature forwarded to ``generate``.
        freq_penalty: frequency penalty forwarded to ``generate``.

    Returns:
        ``gpt2xl.to_string`` applied to the generated token tensor.
    """
    steering_hook = partial(activation_generation_hook, v=steering_vec, indices=indices)
    with gpt2xl.hooks(fwd_hooks=[(site, steering_hook)]):
        output = gpt2xl.generate(
            tokens,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            temperature=temperature,
            freq_penalty=freq_penalty,
            return_type="tensor",
        )
    return gpt2xl.to_string(output)

output = get_steered_completion(tokens, steering_vec=anger_steering_vec, indices=slice(0, 3))

  0%|          | 0/30 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 

In [126]:
print(output)

["<|endoftext|>I think you're a bad person. You should be shot, killed, or some other form of humiliation for what you did to me. I am not responsible for the"]


In [127]:
def is_angry(prompt: str, verbose: bool=False, anger_words: Sequence[str]=("kill", "anger")):
    """Crude keyword check for whether a piece of text sounds angry.

    The match is a case-insensitive *substring* test, so "kill" also matches
    "killed" (and "skill").

    Args:
        prompt: text to inspect.
        verbose: when true, print the first matching word before returning.
        anger_words: words to look for. Defaults to the short list originally
            hard-coded here; pass a longer list (e.g. the one from the AF
            post) for a less crude check.

    Returns:
        True as soon as any word is found in ``prompt``, otherwise False.
    """
    # Lower-case the text once instead of once per candidate word.
    haystack = prompt.lower()
    for word in anger_words:
        if word.lower() in haystack:
            if verbose:
                print(word)
            return True
    return False

In [128]:
is_angry(output[0], True)

kill


True

In [129]:
many_tokens = einops.repeat(
   tokens,
   "1 Seq -> Batch Seq",
   Batch=100,
)

In [130]:
many_outputs = get_steered_completion(
   many_tokens,
   steering_vec=anger_steering_vec,
   indices = slice(0, 3)
)

100%|██████████| 30/30 [00:02<00:00, 11.69it/s]


In [131]:
many_angry = [is_angry(prompt, False) for prompt in many_outputs]
many_angry.count(True) / len(many_angry)

0.01

In [132]:
angry_acts = gpt2xl.run_with_cache_with_saes(
   "Anger",
   names_filter=sae_acts_pre_hook_name,
   saes = [sae],
)[1][sae_acts_pre_hook_name]

In [133]:
most_firing_features = sorted(
   enumerate(angry_acts[0, -1].tolist()), key=lambda x:-x[1]
)

In [134]:
top20_firing_features = [x for x, _ in most_firing_features[:20]]

In [135]:
# What are we going to do now?
#
# Are there several anger features? (That would be cool.)
#
# Could also look into the Eiffel Tower is in Rome steering vector.

In [136]:
# Same recipe as the Anger/Calm steering vector, now contrasting two factual
# completions. NOTE(review): this duplicates that cell and rebinds pos_prompt,
# neg_prompt, pos_tokens, neg_tokens, cache, neg_cache, pos_vec and neg_vec.
pos_prompt = "The Eiffel Tower is in Rome"  # @param {"type": "string"}
neg_prompt = "The Eiffel Tower is in France"  # @param {"type": "string"}

pos_tokens = gpt2xl.to_tokens(pos_prompt, prepend_bos=True)
neg_tokens = gpt2xl.to_tokens(neg_prompt, prepend_bos=True)
# Position-wise subtraction below requires both prompts to tokenize to the
# same shape.
assert pos_tokens.shape == neg_tokens.shape, (pos_tokens.shape, "!=" , neg_tokens.shape)

gpt2xl.reset_hooks()
_, cache = gpt2xl.run_with_cache(
    pos_tokens,
    names_filter = site,
)
# First batch element only.
pos_vec = cache[site][0, :]

gpt2xl.reset_hooks()
_, neg_cache = gpt2xl.run_with_cache(
    neg_tokens,
    names_filter = site,
)

neg_vec = neg_cache[site][0, :]
# Keep only the final position of the scaled difference; the `-1:` slice
# preserves a leading axis of length 1.
rome_steering_vec = (20*(pos_vec - neg_vec))[-1:, :]
# I hope this works

# NOTE(review): `tokens` held a token tensor in earlier cells; from here on it
# is a plain string prompt, and later cells rely on that.
tokens = "To see the eiffel tower, people flock to"

In [137]:
output = get_steered_completion(tokens, steering_vec=rome_steering_vec, indices=slice(pos_tokens.shape[-1]-1, pos_tokens.shape[-1]))

print(output)

100%|██████████| 30/30 [00:02<00:00, 14.43it/s]

["<|endoftext|>To see the eiffel tower, people flock to France's most famous landmark.\n\nRome's historic center is filled with statues and monuments, but one of the most iconic images of France's"]





In [138]:
top20_sae_vis_data = get_sae_vis_data(top20_firing_features)
top20_filenames = save_sae_vis_html(top20_sae_vis_data)

Forward passes to cache data for vis:  19%|█▉        | 6/32 [00:10<00:47,  1.83s/it]

In [None]:
rome_sae_logits, rome_sae_cache = gpt2xl.run_with_cache_with_saes(
    "The Eiffel Tower is in Rome",
    saes = [sae],
    names_filter=sae_acts_pre_hook_name
)

In [None]:
rome_sae_acts = rome_sae_cache[sae_acts_pre_hook_name]

In [None]:
most_firing_features = sorted(
   enumerate(rome_sae_acts[0, -1].tolist()), key=lambda x:-x[1]
)

In [None]:
top20_firing_features = [x for x, _ in most_firing_features[:20]]

In [None]:
rome_top20_sae_vis_data = get_sae_vis_data(top20_firing_features)
top20_filenames = save_sae_vis_html(rome_top20_sae_vis_data)

Forward passes to cache data for vis: 100%|██████████| 32/32 [00:58<00:00,  1.84s/it]

Forward passes to cache data for vis: 100%|██████████| 32/32 [01:00<00:00,  1.89s/it]
Extracting vis data from cached data: 100%|██████████| 20/20 [01:00<00:00,  3.02s/it]
Saving feature-centric vis: 100%|██████████| 20/20 [00:01<00:00, 15.57it/s]
Saving feature-centric vis: 100%|██████████| 20/20 [00:01<00:00, 15.58it/s]
Saving feature-centric vis: 100%|██████████| 20/20 [00:01<00:00, 15.71it/s]
Saving feature-centric vis: 100%|██████████| 20/20 [00:01<00:00, 13.16it/s]
Saving feature-centric vis: 100%|██████████| 20/20 [00:01<00:00, 15.43it/s]
Saving feature-centric vis: 100%|██████████| 20/20 [00:01<00:00, 15.62it/s]
Saving feature-centric vis: 100%|██████████| 20/20 [00:01<00:00, 15.54it/s]
Saving feature-centric vis: 100%|██████████| 20/20 [00:01<00:00, 12.86it/s]
Saving feature-centric vis: 100%|██████████| 20/20 [00:01<00:00, 15.62it/s]
Saving feature-centric vis: 100%|██████████| 20/20 [00:01<00:00, 15.55it/s]
Saving feature-centric vis: 100%|██████████| 20/20 [00:01<00:00, 15.

In [None]:
top20_firing_features

[44460,
 30211,
 42906,
 48631,
 8632,
 42478,
 10730,
 24296,
 45127,
 5630,
 23976,
 24364,
 33218,
 46938,
 5253,
 30603,
 46156,
 27848,
 50584,
 6822]

In [None]:
top20_filenames

['feature_vis_demo_1715307125924.html',
 'feature_vis_demo_1715307127646.html',
 'feature_vis_demo_1715307129388.html',
 'feature_vis_demo_1715307131117.html',
 'feature_vis_demo_1715307133087.html',
 'feature_vis_demo_1715307134825.html',
 'feature_vis_demo_1715307136561.html',
 'feature_vis_demo_1715307138308.html',
 'feature_vis_demo_1715307140313.html',
 'feature_vis_demo_1715307142032.html',
 'feature_vis_demo_1715307143771.html',
 'feature_vis_demo_1715307145476.html',
 'feature_vis_demo_1715307147453.html',
 'feature_vis_demo_1715307149173.html',
 'feature_vis_demo_1715307150891.html',
 'feature_vis_demo_1715307152598.html',
 'feature_vis_demo_1715307154314.html',
 'feature_vis_demo_1715307156281.html',
 'feature_vis_demo_1715307157982.html',
 'feature_vis_demo_1715307159688.html']

You may have a better method of visualizing these dashboards -- I just download and view in Chrome.

Interestingly, the three most active features are high-density and uninterpretable, but the 9th-highest-activating feature is an Italian feature!

![Feature dashboard screenshot](https://i.imgur.com/Y7eavwv.png)

In [None]:
# 150 chosen by making the number big, then easing it off
sae_steering_vec = sae.state_dict()["W_dec"][45127][None] * 150

In [None]:
tokens

'To see the eiffel tower, people flock to'

In [None]:
torch.manual_seed(1)
output = get_steered_completion(torch.cat([gpt2xl.to_tokens(tokens) for _ in range(20)], dim=0), steering_vec=sae_steering_vec, indices=slice(pos_tokens.shape[-1]-2, pos_tokens.shape[-1]))

100%|██████████| 30/30 [00:02<00:00, 13.61it/s]


In [None]:
print("\n\n".join(output))

<|endoftext|>To see the eiffel tower, people flock to the Piazza San Marco.

The grand boulevard in front of the city's main cathedral is packed with tourists, eager to take in

<|endoftext|>To see the eiffel tower, people flock to a cafe in the town of Siena.

It's a place where tourists come to eat and drink and relax. It's also where

<|endoftext|>To see the eiffel tower, people flock to it, but in the desert of Jordan there is no such thing.

It's a strange sight: an empty plain surrounded by high sand d

<|endoftext|>To see the eiffel tower, people flock to the Palazzo Vecchio in Rome.

In this photo taken on Thursday, Oct. 24, 2014, a man walks past a large

<|endoftext|>To see the eiffel tower, people flock to the famous bridge.

The ancient bridge is a tourist attraction in its own right, but it's also a crucial link in the city's water

<|endoftext|>To see the eiffel tower, people flock to a few places in the city.

The Eiffel Tower is one of those places. The French-inspired 

In [None]:
# There sure are a lot of Italian references there!