## Setup

In [1]:
IN_COLAB = False
print("Running as a Jupyter notebook - intended for development only!")
from IPython import get_ipython

ipython = get_ipython()
# Code to automatically update the EasyTransformer code as its edited without restarting the kernel
ipython.magic("load_ext autoreload")
ipython.magic("autoreload 2")

# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from torchtyping import TensorType as TT
from typing import List, Union, Optional, Tuple
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
from torchtyping import TensorType

import circuitsvis
from circuitsvis import attention

import easy_transformer
import easy_transformer.utils as utils
from easy_transformer.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from easy_transformer import EasyTransformer, EasyTransformerConfig, FactoredMatrix, ActivationCache

import pandas as pd
from tqdm import tqdm

torch.set_grad_enabled(False)



Running as a Jupyter notebook - intended for development only!


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


<torch.autograd.grad_mode.set_grad_enabled at 0x7f7ed755aa60>

## Generate bigrams

In [2]:
model = EasyTransformer.from_pretrained(
    "NeelNanda/SoLU_1L512W_C4_Code",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device="cpu"
)

attn_model = EasyTransformer.from_pretrained(
    "NeelNanda/Attn_Only_1L512W_C4_Code",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device="cpu"
    )

tokenizer = model.tokenizer

Loading model: NeelNanda/SoLU_1L512W_C4_Code
Moving model to device:  cpu
Finished loading pretrained model NeelNanda/SoLU_1L512W_C4_Code into EasyTransformer!
Loading model: NeelNanda/Attn_Only_1L512W_C4_Code
Moving model to device:  cpu
Finished loading pretrained model NeelNanda/Attn_Only_1L512W_C4_Code into EasyTransformer!


In [6]:
tokens = []
# 48261
for tokenIndex in tqdm(range(48261)):
    # Get the token
    token = tokenizer.decode(tokenIndex)
    
    # Initialise the token details
    details = {
        "idx": tokenIndex,
        "token": token,
    }
    
    # Get the top k logits & corresponding tokens
    logits = model(token)[0][1]
    probs = F.softmax(logits, dim=-1)
    _, topKTokens = torch.topk(probs, 10, largest=True)
    
    for idx, token_idx in enumerate(topKTokens):
        token_idx = token_idx.item()
        token = tokenizer.decode(token_idx)
        prob = probs[token_idx]
        
        # Add to the details
        details[f"top_{idx}_idx"] = token_idx
        details[f"top_{idx}_token"] = token
        details[f"top_{idx}_prob"] = prob.item()
    
    # Add to tokens list
    tokens.append(details)
    
# Convert to a dataframe
bigrams = pd.DataFrame(tokens)

100%|██████████| 48261/48261 [09:49<00:00, 81.83it/s] 


In [13]:
bigrams.sort_values("top_0_prob", ascending=False).head(10)

Unnamed: 0,idx,token,top_0_idx,top_0_token,top_0_prob,top_1_idx,top_1_token,top_1_prob,top_2_idx,top_2_token,...,top_6_prob,top_7_idx,top_7_token,top_7_prob,top_8_idx,top_8_token,top_8_prob,top_9_idx,top_9_token,top_9_prob
17857,17857,pcbi,16,.,0.999555,533,which,8.1e-05,593,so,...,1.3e-05,345,he,1.1e-05,328,on,9e-06,276,in,9e-06
28707,28707,fasterxml,16,.,0.999089,28,:,0.000374,65,_,...,2.4e-05,25438,jackson,2e-05,15,-,2e-05,348,as,2e-05
28893,28893,umbre,619,ll,0.998891,252,on,0.000132,14392,lla,...,3.2e-05,68,b,3e-05,2139,ley,2.8e-05,1055,led,2.8e-05
46091,46091,HepG,20,2,0.995836,21,3,0.001424,19,1,...,8.8e-05,16,.,7.9e-05,15,-,5.9e-05,26,8,5.4e-05
28110,28110,surfact,1086,ants,0.995655,250,in,0.002054,386,ant,...,7.8e-05,834,ating,7.4e-05,422,ive,4.5e-05,312,ol,4.2e-05
30790,30790,courty,2130,ards,0.994725,12715,arded,0.001346,472,ard,...,0.000114,16,.,0.000113,680,ata,7.4e-05,1172,yl,6.8e-05
21543,21543,pione,398,ers,0.994663,255,er,0.001585,15,-,...,8.9e-05,675,erv,8.9e-05,258,en,8.9e-05,276,in,7.9e-05
39769,39769,googleapis,16,.,0.994287,65,_,0.001472,15034,\.,...,0.000288,41220,-%,0.000228,17,/,0.000198,3,!,0.000116
17490,17490,fibrobl,3153,astic,0.994141,1317,cells,0.000845,10,(,...,0.0001,19987,astically,8.6e-05,16,.,7.7e-05,13687,orer,7.6e-05
5817,5817,https,1333,://,0.993413,28,:,0.003046,16,.,...,4.5e-05,3,!,4.3e-05,15,-,4.1e-05,286,and,4e-05


In [85]:
# hist = bigrams["top_0_prob"].hist(bins=100, label="Top Bigram"), bigrams["top_1_prob"].hist(bins=100), bigrams["top_2_prob"].hist(bins=100)
# hist.legend(loc="upper left")

# # Add labels
# hist.axes[0].set_xlabel("Probability")

# hist

# Get df with column names top_0_prob, top_1_prob, top_2_prob
df = bigrams[["top_0_prob", "top_1_prob", "top_2_prob"]]


# df = px.data.tips()
fig = px.histogram(df)
fig.show()

In [9]:
# Save to CSV
bigrams.to_csv("bigrams.csv", index=False)

### Generate prompts to test

In [69]:

prompts = []
answers_correct = []
answers_incorrect = []

i = 0

# Random sample of bigrams
for _idx, row in bigrams.sample(1000).iterrows():
    # Get the token & most likely next tokens
    token = str(row["token"])
    most_likely = str(row["top_0_token"])
    less_likely = str(row["top_9_token"]) # Can change to 1-9 
    
    # Create the prompt
    prompts.append(f"{token}{less_likely} {token}{less_likely} {token}")
    answers_correct.append(most_likely)
    answers_incorrect.append(most_likely)  
    
    i += 1
    if i >= 1000:
        break  

## Attn vs SoLU

In [70]:
diffs = []
solu_res = []
attn_res = []

for idx, prompt in tqdm(enumerate(prompts)):
    answer = answers_correct[idx]
    answer_token = model.to_single_token(answer)

    solu_logits = model(prompt)
    attn_logits = attn_model(prompt)
    
    solu_last_logits = solu_logits[0][-1]
    attn_last_logits = attn_logits[0][-1]
    
    # print(solu_logits.shape)
    
    solu_probs = F.softmax(solu_last_logits, dim=-1)
    attn_probs = F.softmax(attn_last_logits, dim=-1)
    
    solu_correct_prob = solu_probs[answer_token]
    attn_correct_prob = attn_probs[answer_token]
    
    diff = solu_correct_prob - attn_correct_prob
    
    diffs.append(diff.item())
    solu_res.append(solu_correct_prob.item())
    attn_res.append(attn_correct_prob.item())

    
diffs_series = pd.Series(diffs)
solu_res_series = pd.Series(solu_res)
attn_res_series = pd.Series(attn_res)

# Show summary statistics
diffs_series.describe(), solu_res_series.describe(), attn_res_series.describe()

1000it [00:20, 49.41it/s]


(count    1000.000000
 mean       -0.004473
 std         0.096006
 min        -0.748499
 25%        -0.030827
 50%        -0.000728
 75%         0.018516
 max         0.481189
 dtype: float64,
 count    1.000000e+03
 mean     9.146427e-02
 std      1.442662e-01
 min      5.319623e-07
 25%      6.296380e-03
 50%      3.787390e-02
 75%      1.133159e-01
 max      9.936337e-01
 dtype: float64,
 count    1.000000e+03
 mean     9.593732e-02
 std      1.352703e-01
 min      1.891348e-07
 25%      1.102560e-02
 50%      4.938926e-02
 75%      1.233064e-01
 max      9.606335e-01
 dtype: float64)