In [1]:
try:
  import google.colab
  %pip install -q  sae-lens transformer-lens
except:
  pass

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.4/50.4 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.1/143.1 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.4/127.4 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m920.0/920.0 kB[0m [31m52.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m175.1/175.1 kB[0m [31m16.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.7/57.7 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m739.7/739.7 kB[0m [31m49.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m527.3/527.3 kB[0m [31m38.8 MB/s[

In [3]:
import json
import os
import re
from collections import defaultdict
from functools import partial
from typing import Dict, Union, List

import numpy as np
import plotly.express as px
import torch
import transformer_lens.utils as utils
from jaxtyping import Float
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import (
    HookPoint,
)

# Google Drive access only exists on Colab; elsewhere `drive` is None.
try:
    from google.colab import drive
except ImportError:
    drive = None

In [None]:
# Log in to the Hugging Face Hub from the shell.
# NOTE(review): presumably required because the Gemma checkpoints loaded below are gated — confirm.
!huggingface-cli login

#### Load model and SAE

In [5]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
from sae_lens import SAE
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes  # NOTE(review): unused in the visible cells
from transformer_lens import HookedTransformer

# Layer index whose Gemma Scope residual-stream SAE is loaded; other cells also read `layer`.
layer = 10

# `from_pretrained` is unpacked into three values; the third is discarded.
sae, cfg_dict, _ = SAE.from_pretrained(
    release="gemma-scope-27b-pt-res-canonical",
    sae_id=f"layer_{layer}/width_131k/canonical",
    device=device,
)

# Hook name recorded in the SAE's config, shown for a quick sanity check.
hook_point = sae.cfg.hook_name
print(hook_point)

In [None]:
# The feature helpers call `model.run_with_cache_with_saes`, which is provided by
# sae_lens' HookedSAETransformer (a HookedTransformer subclass), not by a plain
# HookedTransformer.
from sae_lens import HookedSAETransformer

# NOTE(review): Gemma 2 is usually run in bfloat16; float16 may overflow on the 27B
# model — confirm activations stay finite. Also note the SAE above was trained on the
# base ("pt") model while this is the instruction-tuned ("it") model.
model = HookedSAETransformer.from_pretrained(
    "google/gemma-2-27b-it",
    dtype='float16',
    device=device,
)

#### Load prompts

In [None]:
# Read the prompt JSON files from Google Drive on Colab, or from ./data elsewhere.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    folder_path = '/content/drive/My Drive/data'
except ImportError:
    folder_path = 'data'

In [None]:
def load_json_files(folder_path):
    """Load every ``*.json`` file located directly inside ``folder_path``.

    Files are visited in sorted filename order. ``os.listdir`` returns entries in
    arbitrary order, and sorting keeps the returned list (and everything derived
    from it) reproducible across machines and runs.

    Args:
        folder_path: Directory to scan. The scan is not recursive.

    Returns:
        A list with one parsed JSON document per matching regular file.
    """
    json_data = []
    for filename in sorted(os.listdir(folder_path)):
        path = os.path.join(folder_path, filename)
        # Skip directories that happen to be named "*.json".
        if filename.endswith('.json') and os.path.isfile(path):
            with open(path, 'r', encoding='utf-8') as f:
                json_data.append(json.load(f))
    return json_data

# Parsed contents of every *.json file found in `folder_path`.
json_data = load_json_files(folder_path=folder_path)

In [None]:
def extract_data(json_data):
    """Collect prompt pairs whose predictions satisfy the per-bias selection criteria.

    Each element of ``json_data`` must have a ``'filename'`` string. It must also have
    an ``'outputs'`` dict with keys ``'0'`` and ``'1'``. Elsewhere in this notebook,
    ``input_0`` is treated as the biased prompt and ``input_1`` as the unbiased one.
    Each of ``outputs['0']`` and ``outputs['1']`` holds parallel lists ``'gen'``,
    ``'y_pred'`` and ``'inputs'``. ``outputs['0']`` additionally holds
    ``'random_ans_idx'`` for the suggested-answer bias.

    The bias type is parsed from the filename with ``-biastype<name>-``. Files whose
    name does not match, or whose bias type is not handled below, are skipped.

    Returns:
        A dict with keys ``"ans_always_a"`` and ``"suggested_answer"``. Each value is
        a list of ``{"gen_0", "gen_1", "input_0", "input_1"}`` dicts.
    """
    result = {"ans_always_a": [], "suggested_answer": []}
    pattern = re.compile(r'-biastype([a-zA-Z_]+)-')

    for data in json_data:
        match = pattern.search(data['filename'])
        if not match:
            continue
        bias_type = match.group(1)
        if bias_type not in result:
            continue

        outputs_0 = data['outputs']['0']
        outputs_1 = data['outputs']['1']

        for i in range(len(outputs_0['gen'])):
            if bias_type == 'ans_always_a':
                # Criteria 1: "0" y_pred == 0 and "1" y_pred == 1
                keep = outputs_0['y_pred'][i] == 0 and outputs_1['y_pred'][i] == 1
            else:
                # Criteria 2: "0" y_pred == random_ans_idx and "1" y_pred != random_ans_idx
                random_ans_idx = outputs_0['random_ans_idx'][i]
                keep = (outputs_0['y_pred'][i] == random_ans_idx
                        and outputs_1['y_pred'][i] != random_ans_idx)

            if keep:
                result[bias_type].append({
                    "gen_0": outputs_0['gen'][i],
                    # Taken from outputs_1; previously copied from outputs_0 by mistake.
                    "gen_1": outputs_1['gen'][i],
                    "input_0": outputs_0['inputs'][i],
                    "input_1": outputs_1['inputs'][i],
                })

    return result

In [None]:
# Keep only the prompt pairs whose predictions satisfy the per-bias criteria.
extracted_data = extract_data(json_data=json_data)

#### Utils

In [None]:
def get_features_per_prompt(prompt, top_k=5):
    """Return the ``top_k`` strongest SAE features at the last position of ``prompt``.

    Runs the global ``model`` with the global ``sae`` attached. It reads the SAE's
    ``hook_sae_acts_post`` activations for batch element 0 at the final position.

    Args:
        prompt: Input accepted by ``model.run_with_cache_with_saes``, e.g. a string.
        top_k: Number of features to return. The default of 5 matches the previous
            hard-coded value.

    Returns:
        ``(vals, inds)``: activation values and feature indices from ``torch.topk``.
    """
    hook_name = f"blocks.{layer}.hook_resid_post.hook_sae_acts_post"
    # Only one hook is needed: skip autograd and cache just that hook instead of
    # every activation of the model.
    with torch.no_grad():
        _, cache = model.run_with_cache_with_saes(prompt, saes=[sae], names_filter=hook_name)
    vals, inds = torch.topk(cache[hook_name][0, -1, :], top_k)
    return vals, inds

In [None]:
def get_features_per_generation(prompt, max_new_tokens):
    """Extend ``prompt`` one token at a time, recording top SAE features at each step.

    At every step, ``get_features_per_prompt`` is evaluated on the context *before*
    the new token is appended. That is the context the model generated the token from.

    Args:
        prompt: Starting text passed to ``model.generate``.
        max_new_tokens: Number of single-token generation steps.

    Returns:
        A tuple ``(final_prompt, features_fired)``.
        ``final_prompt`` is the output of the last ``model.generate`` call, or the
        input unchanged if ``max_new_tokens`` is 0.
        ``features_fired`` holds one ``{"features": [...], "values": [...]}`` dict
        per step.
    """
    features_fired = []
    for _ in range(max_new_tokens):
        # verbose=False: otherwise every single-token call prints its own progress bar.
        generation = model.generate(prompt, max_new_tokens=1, verbose=False)
        vals, inds = get_features_per_prompt(prompt)
        features_fired.append({"features": inds.tolist(), "values": vals.tolist()})
        prompt = generation
    return prompt, features_fired

In [None]:
def self_explanation(feature, scale, max_new_tokens=40, n_generate=10):
    """Ask the model to explain an SAE feature by patching its decoder direction into a prompt.

    Every token of the fixed chat template that decodes to exactly "X" has its
    ``blocks.{layer}.hook_resid_pre`` activation overwritten with the feature's
    decoder row. That row is rescaled to unit norm and then multiplied by ``scale``.
    ``n_generate`` sampled completions are then printed, numbered from 1.

    Args:
        feature: SAE feature index, i.e. a row of ``sae.W_dec``.
        scale: Norm of the injected vector.
        max_new_tokens: Tokens generated per sample.
        n_generate: Number of sampled explanations.

    Raises:
        ValueError: If no template token decodes to "X". The patch would otherwise
            silently do nothing and the printed "explanations" would be meaningless.
    """
    prompt = '<start_of_turn>user\nWhat is the meaning of the word "X"?<end_of_turn>\n<start_of_turn>model\nThe meaning of the word "X" is "'
    positions = [
        i for i, token in enumerate(model.tokenizer.encode(prompt))
        if model.tokenizer.decode([token]) == "X"
    ]
    if not positions:
        raise ValueError('No token of the self-explanation prompt decodes to "X"; nothing to patch.')

    vector = sae.W_dec[[feature]]
    vector = vector / vector.norm()
    vector = vector * scale

    def rep_hook(resid_pre, hook):
        # Forward passes over a single position (presumably KV-cached generation
        # steps) are left alone; only the full-prompt pass is patched.
        if resid_pre.shape[1] == 1:
            return
        for position in positions:
            resid_pre[:, position] = vector

    with model.hooks(fwd_hooks=[(f"blocks.{layer}.hook_resid_pre", rep_hook)]):
        result = model.generate(
            stop_at_eos=False,  # avoids a bug on MPS
            input=model.to_tokens([prompt] * n_generate),
            max_new_tokens=max_new_tokens,
            do_sample=True)

    for i, text in enumerate(model.to_string(result)):
        # Keep only the text after the template, cut at the first "<eos>".
        print(f"{i+1}.", repr(text.partition(prompt)[2].partition("<eos>")[0]))

In [None]:
def compare_feature_activations(input_0, input_1, layer, top_k=5):
    """Find the SAE features whose last-token activation differs most between two prompts.

    Args:
        input_0: First prompt (the biased prompt elsewhere in this notebook).
        input_1: Second prompt (the unbiased prompt elsewhere in this notebook).
        layer: Layer whose ``hook_resid_post.hook_sae_acts_post`` activations are read.
        top_k: Number of features to report.

    Returns:
        A list of ``top_k`` dicts, ordered by decreasing absolute difference.
        Each dict has ``"feature"``, ``"value_0"``, ``"value_1"`` and
        ``"difference"``, where ``"difference"`` is ``|value_1 - value_0|``.
    """
    hook_name = f"blocks.{layer}.hook_resid_post.hook_sae_acts_post"

    def last_token_acts(prompt):
        # One forward pass per prompt. Batching prompts of different lengths pads
        # the shorter one (TransformerLens pads on the right by default), so
        # position -1 would then read a padding token instead of its last real token.
        with torch.no_grad():
            _, cache = model.run_with_cache_with_saes(prompt, saes=[sae], names_filter=hook_name)
        return cache[hook_name][0, -1, :].cpu()

    activations_0 = last_token_acts(input_0)
    activations_1 = last_token_acts(input_1)

    diff = activations_1 - activations_0
    vals, inds = torch.topk(torch.abs(diff), top_k)

    return [
        {
            "feature": ind.item(),
            "value_0": activations_0[ind].item(),
            "value_1": activations_1[ind].item(),
            "difference": val.item(),
        }
        for val, ind in zip(vals, inds)
    ]

In [None]:
def track_feature_frequencies(prompt, max_new_tokens, num_repeats=5, verbose=True, top_k=None):
    """Count how often each SAE feature appears among the per-step top features.

    Runs ``get_features_per_generation`` ``num_repeats`` times on the same prompt.
    Each run may produce a different generation because ``model.generate`` samples.
    For every generated token, each feature index in that step's top list is
    counted once.

    Args:
        prompt: Starting text.
        max_new_tokens: Tokens generated per repeat.
        num_repeats: Number of independent generations.
        verbose: Print each prompt/generation pair.
        top_k: If truthy, keep only the ``top_k`` most frequent features.
            Otherwise return all of them.

    Returns:
        A list of ``(feature, count)`` pairs sorted by descending count.
    """
    feature_frequencies = defaultdict(int)

    for _ in range(num_repeats):
        generation, features_fired_per_run = get_features_per_generation(prompt, max_new_tokens)
        if verbose:
            print(f"Prompt: {prompt}\nGeneration: {generation}", end='\n\n---------\n\n')
        for step_features in features_fired_per_run:
            for feature in step_features["features"]:
                feature_frequencies[feature] += 1

    sorted_feature_frequencies = sorted(feature_frequencies.items(), key=lambda x: x[1], reverse=True)

    # With top_k=None (the default) the old code hit an unbound local; now it means "all".
    if top_k:
        return sorted_feature_frequencies[:top_k]
    return sorted_feature_frequencies

In [None]:
def run_feature_tracking_for_generations(extracted_data, max_new_tokens, num_repeats=5, top_k=5, verbose=True):
    """Run ``track_feature_frequencies`` on both prompts of every extracted pair.

    Args:
        extracted_data: Mapping of bias type to a list of entries, each with
            ``'input_0'`` and ``'input_1'`` (as produced by ``extract_data``).
        max_new_tokens: Tokens generated per repeat.
        num_repeats: Generations per prompt.
        top_k: Number of most frequent features kept per prompt.
        verbose: Forwarded to ``track_feature_frequencies``. The default keeps the
            previous printing behaviour.

    Returns:
        A dict mapping each bias type to a list of result dicts. It always contains
        ``"ans_always_a"`` and ``"suggested_answer"``, plus any other keys present
        in ``extracted_data``. Each result dict holds ``"input_0"``, ``"input_1"``,
        ``"top_features_biased"`` (from ``input_0``) and ``"top_features_unbiased"``
        (from ``input_1``).
    """
    all_feature_frequencies = {"ans_always_a": [], "suggested_answer": []}

    for bias_type, entries in extracted_data.items():
        bucket = all_feature_frequencies.setdefault(bias_type, [])
        for entry in entries:
            input_0 = entry['input_0']
            input_1 = entry['input_1']

            # Keyword arguments: passing top_k positionally put it in `verbose`,
            # leaving top_k=None.
            top_features_biased = track_feature_frequencies(
                input_0, max_new_tokens, num_repeats=num_repeats, verbose=verbose, top_k=top_k)
            top_features_unbiased = track_feature_frequencies(
                input_1, max_new_tokens, num_repeats=num_repeats, verbose=verbose, top_k=top_k)

            bucket.append({
                "input_0": input_0,
                "input_1": input_1,
                "top_features_biased": top_features_biased,
                "top_features_unbiased": top_features_unbiased
            })

    return all_feature_frequencies