# 0. Imports

In [2]:
import numpy as np
import torch as t

from feature_selection import get_thres_features, get_topk_features
from feature_vis import transform_list, get_neuronpedia_quicklist
from load_models import ModelLoader

from feature_circuits.loading_utils import load_examples
from feature_circuits.attribution import patching_effect

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


# 1. Model and SAE

In [3]:
# Model identifier and dictionary location handed to ModelLoader.
# (model_name looks like a Hugging Face Hub id -- TODO confirm.)
model_name = "EleutherAI/pythia-70m-deduped"
# ADAPT this path to wherever the dictionaries live on your machine.
# NOTE(review): the path uses "feature-circuits" (hyphen) while the imports use
# the package name "feature_circuits" (underscore) -- confirm the directory name.
dictionary_path = "feature-circuits/dictionary_learning" #ADAPT

# Initialize the loader
loader = ModelLoader(model_name=model_name, dictionary_path=dictionary_path)

# load_model() returns five objects that the rest of the notebook relies on:
# the model, its tokenizer, the dictionaries, the submodules they attach to,
# and the submodule names used as keys when selecting features.
model, tokenizer, dictionaries, submodules, submodule_names = loader.load_model()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# 2. Datasets

## 2.1 Stereotypical Professions Dataset

Example: 
- male_prefixes = "The doctor runs because", male_answer = "he"
- female_prefixes = "The nurse runs because", female_answer = "she"

In [4]:
# Both values are forwarded to load_examples below.
# (length is presumably the prefix length in tokens -- TODO confirm.)
num_examples = 50
length = 4

dataset_stereo = 'data/prof_stereo.json'
examples_stereo = load_examples(dataset_stereo, num_examples, model, length=length)

# Concatenate the per-example prefix tensors along dim 0 and move them to DEVICE.
# "clean" entries become the male (m_*) side, "patch" entries the female (f_*) side.
m_inputs = t.cat(
    [ex['clean_prefix'] for ex in examples_stereo],
    dim=0,
).to(DEVICE)
f_inputs = t.cat(
    [ex['patch_prefix'] for ex in examples_stereo],
    dim=0,
).to(DEVICE)

# Answer token indices, one per example, created directly on DEVICE.
m_answer_idxs = t.tensor(
    [ex['clean_answer'] for ex in examples_stereo],
    dtype=t.long,
    device=DEVICE,
)
f_answer_idxs = t.tensor(
    [ex['patch_answer'] for ex in examples_stereo],
    dtype=t.long,
    device=DEVICE,
)

## 2.2 Anti-Stereotypical Professions Dataset

Example: 
- male_prefixes = "The nurse runs because", male_answer = "he"
- female_prefixes = "The doctor runs because", female_answer = "she"

In [5]:
# Both values are forwarded to load_examples below.
# (length is presumably the prefix length in tokens -- TODO confirm.)
num_examples = 50
length = 4

dataset_anti = 'data/prof_anti.json'
examples_anti = load_examples(dataset_anti, num_examples, model, length=length)

# Concatenate the per-example prefix tensors along dim 0 and move them to DEVICE.
# "clean" entries become the male (anti_m_*) side, "patch" entries the female (anti_f_*) side.
anti_m_inputs = t.cat(
    [ex['clean_prefix'] for ex in examples_anti],
    dim=0,
).to(DEVICE)
anti_f_inputs = t.cat(
    [ex['patch_prefix'] for ex in examples_anti],
    dim=0,
).to(DEVICE)

# Answer token indices, one per example, created directly on DEVICE.
anti_m_answer_idxs = t.tensor(
    [ex['clean_answer'] for ex in examples_anti],
    dtype=t.long,
    device=DEVICE,
)
anti_f_answer_idxs = t.tensor(
    [ex['patch_answer'] for ex in examples_anti],
    dtype=t.long,
    device=DEVICE,
)

# 3. Attribution patching

In [6]:
# Define metric - logit diff
def metric_fn(model, clean_answer_idxs, patch_answer_idxs):
    """Logit difference: patch-answer logit minus clean-answer logit.

    The logits are read from ``model.embed_out.output`` at the last position
    of dim 1 (assumes a (batch, seq, vocab) layout -- TODO confirm).

    Args:
        model: object exposing ``embed_out.output``.
        clean_answer_idxs: one index into the last dim per row (clean answer).
        patch_answer_idxs: one index into the last dim per row (patch answer).

    Returns:
        One value per row: ``logit[patch] - logit[clean]``.
    """
    # Slice the final position once and reuse it for both lookups.
    last_logits = model.embed_out.output[:, -1, :]
    patch_logit = t.gather(
        last_logits, dim=-1, index=patch_answer_idxs.view(-1, 1)
    ).squeeze(-1)
    clean_logit = t.gather(
        last_logits, dim=-1, index=clean_answer_idxs.view(-1, 1)
    ).squeeze(-1)
    return patch_logit - clean_logit


In [7]:
# clean_logit_diff
def run_clean_logit_diff(clean_prefixes, clean_answers, patch_answers):
    """Evaluate ``metric_fn`` on the clean prompts with gradients disabled.

    Uses the notebook-level ``model``: the prompts are run inside
    ``model.trace(...)`` and the metric is marked with ``.save()`` so it can
    be returned after the trace block ends.

    Args:
        clean_prefixes: batch of clean prompts passed to ``model.trace``.
        clean_answers: clean answer indices forwarded to ``metric_fn``.
        patch_answers: patch answer indices forwarded to ``metric_fn``.

    Returns:
        The saved result of ``metric_fn`` for the batch.
    """
    with t.no_grad():
        with model.trace(clean_prefixes):
            logit_diff = metric_fn(model, clean_answers, patch_answers).save()
    return logit_diff

## 3.1 Stereotypical Professions

In [9]:
# Baseline on the unpatched stereotypical (male) prompts: metric_fn returns the
# patch-answer logit minus the clean-answer logit; print its mean over the batch.
clean_stereo = run_clean_logit_diff(clean_prefixes=m_inputs, clean_answers=m_answer_idxs, patch_answers=f_answer_idxs)
print(clean_stereo.mean())

tensor(-2.9169, device='cuda:0')


In [10]:
# Attribute the metric to dictionary features when patching from the male
# (clean) prompts to the female (patch) prompts.
# method="ig" is presumably integrated gradients -- TODO confirm in patching_effect.
stereo_effects = patching_effect(
    m_inputs,
    f_inputs,
    model,
    submodules,
    dictionaries,
    metric_fn=metric_fn,
    method="ig",
    metric_kwargs=dict(
        clean_answer_idxs=m_answer_idxs,
        patch_answer_idxs=f_answer_idxs,
    ),
)

In [11]:
# The patching_effect result unpacks into four fields; only the last one is
# used here (one value per example, labelled "Total Effect" in the print below).
_, _, _, s_total = stereo_effects
print(s_total)
print('Mean Total Effect:', s_total.mean())

tensor([ 6.0336,  2.0667,  5.2272,  2.7300,  4.4751,  0.5048,  2.7223,  2.6792,
         5.6124, -0.1039,  2.1940,  7.1046,  4.6843,  4.6309,  3.4276,  5.5002,
         3.7698,  2.7712, -0.4904,  0.7440,  2.2174,  2.7968,  0.6442,  8.2258,
         2.6172,  4.8722,  3.9769,  2.7792,  5.2709,  3.3257,  2.5522,  2.7352,
         4.8757,  4.0858,  5.5792,  3.3850,  0.3947,  1.2618,  0.0912,  4.7417,
         3.9136,  2.1780,  6.2887,  3.2841,  3.3339,  1.1440, -0.1676,  6.0646,
         1.7838,  0.3525], device='cuda:0')
Mean Total Effect: tensor(3.2178, device='cuda:0')


## 3.2 Anti-Stereotypical Professions

In [12]:
# Baseline on the unpatched anti-stereotypical (male) prompts: metric_fn returns
# the patch-answer logit minus the clean-answer logit; print its mean over the batch.
clean_anti = run_clean_logit_diff(clean_prefixes=anti_m_inputs, clean_answers=anti_m_answer_idxs, patch_answers=anti_f_answer_idxs)
print(clean_anti.mean())

tensor(0.3009, device='cuda:0')


In [15]:
# Attribute the metric to dictionary features when patching from the
# anti-stereotypical male (clean) prompts to the female (patch) prompts.
# method="ig" is presumably integrated gradients -- TODO confirm in patching_effect.
anti_effects = patching_effect(
    anti_m_inputs,
    anti_f_inputs,
    model,
    submodules,
    dictionaries,
    metric_fn=metric_fn,
    method="ig",
    metric_kwargs=dict(
        clean_answer_idxs=anti_m_answer_idxs,
        patch_answer_idxs=anti_f_answer_idxs,
    ),
)

In [16]:
# The patching_effect result unpacks into four fields; only the last one is
# used here (one value per example, labelled "Total Effect" in the print below).
_, _, _, a_total = anti_effects
print(a_total)
print('Mean Total Effect:', a_total.mean())

tensor([-6.0336, -2.0667, -5.2272, -2.7300, -4.4751, -0.5048, -2.7223, -2.6792,
        -5.6124,  0.1039, -2.1940, -7.1046, -4.6843, -4.6309, -3.4276, -5.5002,
        -3.7698, -2.7712,  0.4904, -0.7440, -2.2174, -2.7968, -0.6442, -8.2258,
        -2.6172, -4.8722, -3.9769, -2.7792, -5.2709, -3.3257, -2.5522, -2.7352,
        -4.8757, -4.0858, -5.5792, -3.3850, -0.3947, -1.2618, -0.0912, -4.7417,
        -3.9136, -2.1780, -6.2887, -3.2841, -3.3339, -1.1440,  0.1676, -6.0646,
        -1.7838, -0.3525], device='cuda:0')
Mean Total Effect: tensor(-3.2178, device='cuda:0')


Same results as male (up to sign): the shift is treated similarly.

# 4. Feature Analysis and Interpretation

## 4.1 Feature Selection

### 4.1.0 Functions

In [17]:
# Per-submodule feature scores: sum over dim 1, then mean over dim 0.
def get_nodes(clean_effects, clean_inputs, submodules):
    """Aggregate attribution effects into one score object per submodule.

    Each value of ``clean_effects`` is summed over ``dim=1`` and then averaged
    over ``dim=0`` (presumably token positions and examples respectively --
    TODO confirm against the producer of ``clean_effects``).

    An earlier version multiplied every entry by ``len(clean_inputs)`` and
    divided by a running total equal to that same length. This was a leftover
    from a multi-batch loop whose accumulation branch could never execute.
    For a single batch the two factors cancel, so the aggregate is computed
    directly. Results can differ from the old code only by floating-point
    rounding, and an empty ``clean_inputs`` no longer causes a division by zero.

    Args:
        clean_effects: mapping from submodule key to an object supporting
            ``.sum(dim=...)`` and ``.mean(dim=...)`` (e.g. a tensor).
        clean_inputs: unused; kept so existing calls keep working.
        submodules: unused; kept so existing calls keep working.

    Returns:
        dict with the same keys as ``clean_effects``, each mapped to its
        aggregated effect.
    """
    # Aggregation is pure post-processing; no autograd graph is needed.
    with t.no_grad():
        return {
            key: effect.sum(dim=1).mean(dim=0)
            for key, effect in clean_effects.items()
        }

### 4.1.1 Stereotypical Professions

In [18]:
# Collapse the stereotypical effects to one score object per submodule.
s_nodes = get_nodes(stereo_effects.effects, m_inputs, submodules)

# Features selected by get_thres_features (its threshold is defined in
# feature_selection); the result maps submodule name -> {feature index: score}.
s_thres = get_thres_features(s_nodes, submodule_names)
print(f"Total number of features with activation score above threshold: {sum(len(inner_dict) for inner_dict in s_thres.values())}")
print(s_thres)

# The 30 features returned by get_topk_features (top_n=30).
s_top30 = get_topk_features(s_nodes, submodule_names=submodule_names, top_n=30)
print(s_top30)

Total number of features with activation score above threshold: 40
{'attn_0': {}, 'attn_1': {}, 'attn_2': {13570: 0.1322668194770813, 27472: 0.8356579542160034}, 'attn_3': {2959: 0.83766770362854, 19128: 0.3139243423938751}, 'attn_4': {16163: 0.11511898040771484, 31101: 0.41737860441207886}, 'attn_5': {22001: 0.20258422195911407, 30031: 0.15391023457050323}, 'mlp_0': {10221: 0.20429246127605438, 14414: 0.2787262201309204, 16632: 0.2810131907463074, 17854: 0.19017906486988068}, 'mlp_1': {25018: 0.27510756254196167}, 'mlp_2': {}, 'mlp_3': {}, 'mlp_4': {}, 'mlp_5': {26689: -0.530010998249054, 27658: -0.3828776478767395}, 'resid_0': {644: 0.10349516570568085, 2493: 0.2558082044124603, 4452: 0.1213822141289711, 5196: 0.14977960288524628, 10590: 0.19143302738666534, 19026: 0.19144435226917267, 21898: 0.19873341917991638, 24096: 0.11161351203918457, 31617: 0.1020236387848854}, 'resid_1': {8194: 0.13810846209526062, 23470: 0.13365794718265533, 28555: 0.1833762675523758, 30248: 0.15729229152202

### 4.1.2 Anti-Stereotypical Professions

In [19]:
# Collapse the anti-stereotypical effects to one score object per submodule.
# Pass anti_m_inputs, the clean batch anti_effects was computed on, rather than
# f_inputs, which belongs to the stereotypical dataset.
a_nodes = get_nodes(anti_effects.effects, anti_m_inputs, submodules)

# Features selected by get_thres_features (its threshold is defined in
# feature_selection); the result maps submodule name -> {feature index: score}.
a_thres = get_thres_features(a_nodes, submodule_names)
print(f"Total number of features with activation score above threshold: {sum(len(inner_dict) for inner_dict in a_thres.values())}")
print(a_thres)

# The 30 features returned by get_topk_features (top_n=30).
a_top30 = get_topk_features(a_nodes, submodule_names=submodule_names, top_n=30)
print(a_top30)

Total number of features with activation score above threshold: 47
{'attn_0': {}, 'attn_1': {}, 'attn_2': {13570: -0.1899036169052124, 27472: -0.5888562798500061}, 'attn_3': {2959: -0.5657263994216919, 19128: -0.4300365746021271}, 'attn_4': {16163: -0.1410037726163864, 22821: -0.19399070739746094, 31101: -0.23284615576267242}, 'attn_5': {22001: -0.1997459977865219, 30031: -0.15441685914993286}, 'mlp_0': {204: -0.10251283645629883, 10221: -0.16574573516845703, 14414: -0.22566348314285278, 15039: -0.10155228525400162, 16632: -0.3041258454322815, 17839: -0.11781653016805649, 17854: -0.1368584781885147, 23134: -0.10225982218980789, 24339: -0.11601391434669495, 28952: -0.1017894521355629}, 'mlp_1': {25018: -0.18738269805908203}, 'mlp_2': {}, 'mlp_3': {}, 'mlp_4': {}, 'mlp_5': {26689: 0.540125846862793, 27658: 0.3690444231033325}, 'resid_0': {644: -0.10307897627353668, 2493: -0.25543642044067383, 4452: -0.135439932346344, 5196: -0.13899902999401093, 10590: -0.201400026679039, 19026: -0.19565

## 4.2 Visualisation via Neuronpedia

In [20]:
s_features = transform_list(s_thres, 'pythia-70m-deduped')
get_neuronpedia_quicklist(s_features, "Gender features stereotypical professions")

Opening: https://neuronpedia.org/quick-list/?name=Gender%20features%20stereotypical%20professions&features=%5B%7B%22modelId%22%3A%20%22pythia-70m-deduped%22%2C%20%22layer%22%3A%20%222-att-sm%22%2C%20%22index%22%3A%20%2213570%22%7D%2C%20%7B%22modelId%22%3A%20%22pythia-70m-deduped%22%2C%20%22layer%22%3A%20%222-att-sm%22%2C%20%22index%22%3A%20%2227472%22%7D%2C%20%7B%22modelId%22%3A%20%22pythia-70m-deduped%22%2C%20%22layer%22%3A%20%223-att-sm%22%2C%20%22index%22%3A%20%222959%22%7D%2C%20%7B%22modelId%22%3A%20%22pythia-70m-deduped%22%2C%20%22layer%22%3A%20%223-att-sm%22%2C%20%22index%22%3A%20%2219128%22%7D%2C%20%7B%22modelId%22%3A%20%22pythia-70m-deduped%22%2C%20%22layer%22%3A%20%224-att-sm%22%2C%20%22index%22%3A%20%2216163%22%7D%2C%20%7B%22modelId%22%3A%20%22pythia-70m-deduped%22%2C%20%22layer%22%3A%20%224-att-sm%22%2C%20%22index%22%3A%20%2231101%22%7D%2C%20%7B%22modelId%22%3A%20%22pythia-70m-deduped%22%2C%20%22layer%22%3A%20%225-att-sm%22%2C%20%22index%22%3A%20%2222001%22%7D%2C%20%7B%22mod

In [21]:
a_features = transform_list(a_thres, 'pythia-70m-deduped')
get_neuronpedia_quicklist(a_features, "Gender features anti-stereotypical professions")

Opening: https://neuronpedia.org/quick-list/?name=Gender%20features%20anti-stereotypical%20professions&features=%5B%7B%22modelId%22%3A%20%22pythia-70m-deduped%22%2C%20%22layer%22%3A%20%222-att-sm%22%2C%20%22index%22%3A%20%2213570%22%7D%2C%20%7B%22modelId%22%3A%20%22pythia-70m-deduped%22%2C%20%22layer%22%3A%20%222-att-sm%22%2C%20%22index%22%3A%20%2227472%22%7D%2C%20%7B%22modelId%22%3A%20%22pythia-70m-deduped%22%2C%20%22layer%22%3A%20%223-att-sm%22%2C%20%22index%22%3A%20%222959%22%7D%2C%20%7B%22modelId%22%3A%20%22pythia-70m-deduped%22%2C%20%22layer%22%3A%20%223-att-sm%22%2C%20%22index%22%3A%20%2219128%22%7D%2C%20%7B%22modelId%22%3A%20%22pythia-70m-deduped%22%2C%20%22layer%22%3A%20%224-att-sm%22%2C%20%22index%22%3A%20%2216163%22%7D%2C%20%7B%22modelId%22%3A%20%22pythia-70m-deduped%22%2C%20%22layer%22%3A%20%224-att-sm%22%2C%20%22index%22%3A%20%2222821%22%7D%2C%20%7B%22modelId%22%3A%20%22pythia-70m-deduped%22%2C%20%22layer%22%3A%20%224-att-sm%22%2C%20%22index%22%3A%20%2231101%22%7D%2C%20%7B%