---
## general setup (don't bother)

In [None]:
%%capture
# Install (or upgrade to the latest release of) the SAE tooling used by this notebook.
# %%capture hides the installer output.
# NOTE(review): no versions are pinned, so later cells run against whatever is newest — consider pinning.
!uv pip install --upgrade sae-lens transformer-lens sae-dashboard

In [None]:
from torch import Tensor
from transformer_lens import utils
from functools import partial
from tqdm import tqdm
from jaxtyping import Int, Float

import torch, pathlib, pandas as pd
import huggingface_hub as hf_hub, safetensors as st

# pick the compute device: Apple MPS first, then CUDA, then plain CPU
if torch.backends.mps.is_available():
  device = 'mps'
elif torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

# notebook-wide flags: DEBUG gates the extra print output in later cells,
# ONCE is the guard used by the model-loading cell
DEBUG = True
ONCE = False

In [None]:
from sae_lens import SAE, HookedSAETransformer
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
from datasets import load_dataset
from transformer_lens.utils import tokenize_and_concatenate

---
## metadata for SAE exploration

In [None]:
# TODO: Make this nicer.
# Table of the pretrained-SAE directory: one row per directory entry, one column per
# attribute of that entry, minus the columns we do not need for browsing.
df = (
  pd.DataFrame.from_records({name: entry.__dict__ for name, entry in get_pretrained_saes_directory().items()})
  .T
  .drop(columns=['expected_var_explained', 'expected_l0', 'config_overrides', 'conversion_func'])
)

# which model / SAE release to explore
MODEL = 'google/gemma-2-9b-it'
SAE_ID = 'gemma-scope-9b-it-res'
# layer used to build the SAE id and the Neuronpedia dashboard id in later cells
layer = 20

df

In [None]:
# In debug mode, list every SAE id of the selected release with its hook point.
if DEBUG:
  print(f'SAEs in the {SAE_ID}')
  # first row whose release matches SAE_ID; its saes_map is iterated as (SAE id, hook point) pairs
  for k, v in df.loc[df['release'] == SAE_ID, 'saes_map'].iloc[0].items():
    print(f'SAE id: {k} for hook point: {v}')

---
## models

We will be using [GemmaScope](https://huggingface.co/google/gemma-scope-9b-it-res/tree/main) and [google/gemma-2-9b-it](https://huggingface.co/google/gemma-2-9b-it) for exploration.

In [None]:
%%capture
if not ONCE:model = HookedSAETransformer.from_pretrained(MODEL, device=device); ONCE=True

In [None]:
# Load the pretrained SAE for the chosen layer from the selected release.
# NOTE(review): the 3-tuple unpack assumes from_pretrained returns (sae, cfg_dict, sparsity);
# the install cell does not pin sae-lens, so confirm this against the installed version.
sae, cfg_dict, sparsity = SAE.from_pretrained(
  release=SAE_ID,
  sae_id=f'layer_{layer}/width_131k/average_l0_81',  # test with L0_81
  device=device,
)

In [None]:
# hook name recorded in the SAE config; used below to index the activation cache
hook_point = sae.cfg.hook_name

# tag the SAE config with the model name
sae.cfg.neuronpedia_id = MODEL

# show every config field as a plain dict
vars(sae.cfg)

In [None]:
# text corpus, loaded eagerly (not streamed)
dataset = load_dataset('NeelNanda/pile-10k', split='train', streaming=False)

# tokenize with the model's tokenizer, using the SAE's context size and BOS setting
token_dataset = tokenize_and_concatenate(
  dataset=dataset,
  tokenizer=model.tokenizer,
  max_length=sae.cfg.context_size,
  add_bos_token=sae.cfg.prepend_bos,
  # NOTE(review): streaming=True here although the dataset above is loaded with streaming=False — confirm this is intended
  streaming=True,
)

In [None]:
# prompt used both for inspecting SAE activations and, later, for steering
sv_prompt = (
  'The chiffonier stood a few feet from the foot of the bed. '
  'He had emptied the drawers into cartons that morning, which were in the living room.'
)

# token ids of the prompt, for reference
tokens = model.to_tokens(sv_prompt)
print(tokens)

# run the model once and keep every intermediate activation
sv_logits, cache = model.run_with_cache(sv_prompt, prepend_bos=True)

# SAE feature activations at the hook point, and the SAE's reconstruction from them
sv_feature_acts = sae.encode(cache[hook_point])
sae_out = sae.decode(sv_feature_acts)

# three strongest features along the last axis; the indices are what we care about
sv_topk = torch.topk(sv_feature_acts, 3)
print(sv_topk)
top_indices = sv_topk.indices.tolist()

In [None]:
from IPython.display import IFrame, HTML

# draw one feature index uniformly from [0, d_sae)
# NOTE(review): the dashboard call below passes a hard-coded feature index instead of this value — confirm which is intended
feature_idx = torch.randint(low=0, high=sae.cfg.d_sae, size=(1,)).item()

# Neuronpedia embed URL; the three slots are filled positionally by get_dashboard_html
html_template = 'https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true'

def get_dashboard_html(sae_release='gemma-2-9b-it', sae_id=f'{layer}-gemmascope-res-131k', feature_idx=0):
  """
  Build the Neuronpedia embed URL for one SAE feature, print it, and return it.

  The three arguments fill the three positional slots of `html_template`, in order.
  Note that the default `sae_id` is computed from `layer` once, when this function
  is defined, not on each call.
  """
  url = html_template.format(sae_release, sae_id, feature_idx)
  print(url)
  return url


# embed the Neuronpedia dashboard for feature 43499 (the feature steered with later on)
html = get_dashboard_html(feature_idx=43499)
IFrame(src=html, width=1800, height=1800)

---
## steering logic

In [None]:
def find_max_activation(model, sae, activation_store, feature_idx, num_batches=100):
  """
  Find the maximum activation for a given feature index. This is useful for
  calibrating the right amount of the feature to add.

  Args:
    model: object exposing `run_with_cache(tokens, stop_at_layer=..., names_filter=...)`.
    sae: SAE exposing `encode` plus `cfg.hook_name` and `cfg.hook_layer`.
    activation_store: object whose `get_batch_tokens()` yields one token batch per call.
    feature_idx: index of the feature along the last axis of the SAE activations.
    num_batches: how many batches to scan.

  Returns:
    The largest activation seen for `feature_idx` as a Python float. The running
    maximum starts at 0.0, so the result is never below 0.0.
  """
  max_activation = 0.0

  pbar = tqdm(range(num_batches))
  # pure inference: no gradients are needed, so do not build an autograd graph
  with torch.no_grad():
    for _ in pbar:
      tokens = activation_store.get_batch_tokens()

      # only the SAE's hook is needed, so stop right after its layer and cache just that hook
      _, cache = model.run_with_cache(tokens, stop_at_layer=sae.cfg.hook_layer + 1, names_filter=[sae.cfg.hook_name])
      feature_acts = sae.encode(cache[sae.cfg.hook_name])

      # Index the feature on the last axis directly. The previous squeeze() + flatten(0, 1)
      # dropped a size-1 batch/sequence dimension and then failed on the 2-D index.
      batch_max_activation = feature_acts[..., feature_idx].max().item()
      max_activation = max(max_activation, batch_max_activation)

      pbar.set_description(f'Max activation: {max_activation:.4f}')

  return max_activation


def steering(activations, hook, steering_strength=1.0, steering_vector=None, max_act=1.0):
  """
  Forward hook that shifts activations along a steering vector.

  Returns `activations + max_act * steering_strength * steering_vector`.
  `hook` is accepted to match the hook call signature and is not used.
  `steering_vector` must be supplied; the `None` default is only a placeholder.
  """
  scale = max_act * steering_strength
  return activations + scale * steering_vector

def generate(model, prompt,
             max_new_tokens=95, temperature=0.7, top_p=0.9):
  """
  Sample a continuation of `prompt` from `model` without any steering.

  Args:
    model: object exposing `generate(...)` and `tokenizer.decode(...)`.
    prompt: the text to continue.
    max_new_tokens, temperature, top_p: sampling settings forwarded to `model.generate`.

  Returns:
    `(text, output)`: the decoded first sequence and the raw `model.generate` output.

  Reads the notebook globals `device` and `sae` (for `sae.cfg.prepend_bos`).
  """
  output = model.generate(
    prompt,  # was the global `sv_prompt`, which silently ignored the argument
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_p=top_p,
    # EOS stopping is switched off on MPS only — presumably a backend workaround, TODO confirm
    stop_at_eos=device != 'mps',
    prepend_bos=sae.cfg.prepend_bos,
    return_type="tensor",
  )
  return model.tokenizer.decode(output[0]), output

def generate_with_steering(model, sae, prompt, steering_feature, max_act, steering_strength=1.0, max_new_tokens=95,
                           temperature=0.7, top_p=0.9):
  """
  Sample a continuation of `prompt` while adding an SAE decoder direction at the SAE's hook.

  Args:
    model: object exposing `to_tokens`, `hooks`, `generate`, `cfg.device` and `tokenizer.decode`.
    sae: SAE providing `W_dec`, `cfg.hook_name` and `cfg.prepend_bos`.
    prompt: the text to continue.
    steering_feature: row of `sae.W_dec` used as the steering vector.
    max_act: activation scale for the feature (see `find_max_activation`).
    steering_strength: multiplier applied on top of `max_act`.
    max_new_tokens, temperature, top_p: sampling settings forwarded to `model.generate`;
      the defaults match the values that used to be hard-coded here.

  Returns:
    `(text, output)`: the decoded first sequence and the raw `model.generate` output.

  Reads the notebook global `device` and the sibling hook function `steering`.
  """
  input_ids = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)

  # decoder row of the chosen feature, moved onto the model's device
  steering_vector = sae.W_dec[steering_feature].to(model.cfg.device)

  steering_hook = partial(
    steering, steering_vector=steering_vector, steering_strength=steering_strength, max_act=max_act
  )

  # standard transformerlens syntax for a hook context for generation
  with model.hooks(fwd_hooks=[(sae.cfg.hook_name, steering_hook)]):
    output = model.generate(
      input_ids,
      max_new_tokens=max_new_tokens,
      temperature=temperature,
      top_p=top_p,
      # EOS stopping is switched off on MPS only — presumably a backend workaround, TODO confirm
      stop_at_eos=device != 'mps',
      prepend_bos=sae.cfg.prepend_bos,
    )

  return model.tokenizer.decode(output[0]), output

def beautify_generations(model, prompt, sae, output, color="blue"):
  """
  Display a generation as HTML, with the prompt in black and the continuation in `color`.

  Args:
    model: object exposing `to_tokens` and `tokenizer.decode`.
    prompt: the prompt text that `output` was generated from.
    sae: SAE whose `cfg.prepend_bos` controls how the prompt is tokenized.
    output: generated token sequences; only `output[0]` is shown.
    color: CSS color for the generated part.

  Returns:
    None; the HTML is rendered via `display`.
  """
  # local import: the notebook rebinds the name `html` to a URL string
  from html import escape

  input_ids = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)
  # Decode the prompt and generated text separately
  prompt_len = len(input_ids[0])
  prompt_tokens = output[0][:prompt_len]
  generated_tokens = output[0][prompt_len:]

  # Escape the decoded text: special tokens such as <bos> would otherwise be swallowed
  # as unknown HTML tags, and generated text could inject markup into the page.
  prompt_text = escape(model.tokenizer.decode(prompt_tokens))
  generated_text = escape(model.tokenizer.decode(generated_tokens))
  safe_color = escape(str(color))

  # Create HTML with different colors
  html_output = f"""
<div style="font-family: monospace;">
    <span style="color: black;">{prompt_text}</span>
    <span style="color: {safe_color}; font-weight: bold;">{generated_text}</span>
</div>
"""

  # Display the colored output
  display(HTML(html_output))

---
## Play

In [None]:
# Feature whose decoder direction we steer with
steering_feature = 43499 # related to furniture

# Activation scale for that feature, hard-coded here rather than measured in this cell
# NOTE: we could also get the max activation from Neuronpedia (https://www.neuronpedia.org/api-doc#tag/lookup/GET/api/feature/{modelId}/{layer}/{index})
max_act = 50.66
print(f'Maximum activation for feature {steering_feature}: {max_act:.4f}')

# Baseline: the same prompt without any steering
normal_text, normal_output = generate(model, sv_prompt)
beautify_generations(model, sv_prompt, sae, normal_output)

if DEBUG:
  print('\nNormal text (without steering):', normal_text)

# Steered run: add the feature direction at 3x the activation scale
steered_text, steered_output = generate_with_steering(
  model, sae, sv_prompt, steering_feature, max_act, steering_strength=3.0
)
beautify_generations(model, sv_prompt, sae, steered_output, color="red")
if DEBUG:
  print('Steered text:\n', steered_text)