In [3]:
from datasets import load_dataset  
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import einops
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f'{device}')
model = HookedTransformer.from_pretrained('gemma-2b', device=device)
model.to(device)
print('DonE! Happy InterpreTING!!')

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


cuda:0


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer
Moving model to device:  cuda:0
DonE! Happy InterpreTING!!


In [5]:
from transformer_lens.utils import tokenize_and_concatenate

dataset = load_dataset(
    path = "NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

In [6]:
dataset['text'][45]

'This application is based upon and claims the benefit of priority from the prior Japanese Patent Application No. 2000-159163, filed Mar. 31, 2000, the entire contents of which are incorporated herein by reference.\nThe present invention relates to a method of forming a composite member, in which a conductive portion is formed in an insulator, the composite member being used in, for example, a wiring board in the fields of electric appliances, electronic appliances and electric and electronic communication. The present invention also relates to a photosensitive composition and an insulating material that can be suitably used in the manufacturing method of the composite member. Further, the present invention relates to a composite member manufactured by the manufacturing method of the present invention and to a multi-layer wiring board and an electronic package including the particular composite member.\nIn recent years, increase in the degree of integration and miniaturization of vario

In [7]:
import tqdm

In [8]:
batch_str = []
for i in tqdm.trange(len(dataset['text'][:500])):
    batch_str.append(dataset['text'][i][:128])

name = 'blocks.3.hook_resid_post'

  0%|          | 0/500 [00:00<?, ?it/s]

100%|██████████| 500/500 [00:18<00:00, 27.65it/s]


In [9]:
batch_str[423]

'Cast metal bases as an economical alternative for the severely resorbed mandible.\nResorption of the alveolar ridge is a common p'

In [10]:
def activation_filter(name: str) -> bool:
    return name.endswith('blocks.3.hook_resid_post')

In [11]:
from sae_lens import SAE

sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = 'gemma-2b-res-jb',
    sae_id = 'blocks.6.hook_resid_post',
    device = device
)

In [12]:
sae.W_dec.shape

torch.Size([16384, 2048])

In [13]:
num_features = 10
feature_dim = model.cfg.d_model
# features = torch.randn(num_features, feature_dim).to(device)
# take features from rows of the decoder
features = sae.W_dec[:num_features].to(device)
features = F.normalize(features, p=2, dim=-1)

In [14]:
features.shape

torch.Size([10, 2048])

In [12]:
def activation_turn_off(activations, hook, idx=1):
    global features
    direction = features[idx]
    activations = activations.clone()
    component = activations @ direction
    activations = activations - component[:, :, None] * direction[None, None, :]
    return activations

In [13]:
def entropy(probs):
    eps = 1e-8
    return -torch.sum(probs * torch.log(probs + eps), dim=-1)

In [14]:
torch.cuda.empty_cache()

In [15]:
batch_size = 8
information_gains = [0 for _ in range(num_features)]

for feature_idx in tqdm.trange(num_features):
    for i in tqdm.trange(0, len(batch_str), batch_size):
        str_batch = batch_str[i:i + batch_size]
        model.reset_hooks()
        logits = model.run_with_hooks(str_batch, return_type='logits')
        probs = F.softmax(logits, dim=-1)[:, :, :]
        entropies_before = entropy(probs)
        torch.cuda.empty_cache()
        logits_after = model.run_with_hooks(str_batch, return_type='logits', fwd_hooks=[(
            activation_filter,
            lambda x, hook: activation_turn_off(x, hook, idx=feature_idx)
            )])
        probs_after = F.softmax(logits_after, dim=-1)[:, :, :]
        entropies_after = entropy(probs_after)
        torch.cuda.empty_cache()
        diff_entropies = entropies_before - entropies_after
        feature_importance = - diff_entropies.sum()
        information_gains[feature_idx] += feature_importance.item()
        torch.cuda.empty_cache()

100%|██████████| 63/63 [00:26<00:00,  2.34it/s]
100%|██████████| 63/63 [00:25<00:00,  2.51it/s]
100%|██████████| 63/63 [00:25<00:00,  2.49it/s]
100%|██████████| 63/63 [00:25<00:00,  2.52it/s]
100%|██████████| 63/63 [00:25<00:00,  2.49it/s]
100%|██████████| 63/63 [00:25<00:00,  2.50it/s]
100%|██████████| 63/63 [00:25<00:00,  2.52it/s]
100%|██████████| 63/63 [00:25<00:00,  2.48it/s]
100%|██████████| 63/63 [00:25<00:00,  2.46it/s]
100%|██████████| 63/63 [00:25<00:00,  2.47it/s]
100%|██████████| 10/10 [04:14<00:00, 25.47s/it]


In [22]:
# normalize information gains
information_gains_ = [ig / len(batch_str) for ig in information_gains]

In [24]:
# store them in a txt file
with open('information_gains.txt', 'w') as f:
    for ig in information_gains_:
        f.write(f'{ig}\n')

In [1]:
# load them up
information_gains_ = []
with open('information_gains.txt', 'r') as f:
    for line in f:
        information_gains_.append(float(line))

In [23]:
fig = go.Figure()
fig.add_trace(go.Bar(x=[f'#{i}' for i in range(num_features)], y=information_gains_, marker_color='DarkRed'))
fig.update_layout(title='Information Gain from first 10 SAE features')
fig.update_layout(plot_bgcolor='white')
fig.update_xaxes(title_text='Feature #')
fig.update_yaxes(title_text='Information Gain')
fig.update_layout(width=800, height=500)
# y-axis grid lines
fig.update_yaxes(showgrid=True, gridwidth=2, gridcolor='DarkGray')
# show zero line
fig.update_yaxes(zeroline=True, zerolinewidth=2, zerolinecolor='Black')
fig.show()

In [29]:
fig = go.Figure()
fig.add_trace(go.Bar(x=[f'#{i}' for i in range(1000)], y=sorted(information_gains_, reverse=True), marker_color='DarkRed', width=2))
fig.update_layout(title='Information Gain from first 1000 SAE features')
fig.update_layout(plot_bgcolor='white')
fig.update_xaxes(title_text='Feature #')
fig.update_yaxes(title_text='Information Gain')
fig.update_layout(width=800, height=500)
# y-axis grid lines
fig.update_yaxes(showgrid=True, gridwidth=2, gridcolor='DarkGray')
# show zero line
fig.update_yaxes(zeroline=True, zerolinewidth=2, zerolinecolor='Black')
fig.show()

In [17]:
top_100_features = sorted(range(len(information_gains_)), key=lambda i: information_gains_[i], reverse=True)[:100]

In [18]:
fig = go.Figure()
fig.add_trace(go.Bar(x=[f'#{i}' for i in top_100_features], y=[information_gains_[i] for i in top_100_features], marker_color='DarkRed'))
fig.update_layout(title='Information Gain from top 100 SAE features')
fig.update_layout(plot_bgcolor='white')
fig.update_xaxes(title_text='Feature #')
fig.update_yaxes(title_text='Information Gain')
fig.update_layout(width=800, height=500)
# y-axis grid lines
fig.update_yaxes(showgrid=True, gridwidth=2, gridcolor='DarkGray')
# show zero line
fig.update_yaxes(zeroline=True, zerolinewidth=2, zerolinecolor='Black')
fig.show()