## Setup

In [1]:
IN_COLAB = False
print("Running as a Jupyter notebook - intended for development only!")
from IPython import get_ipython

ipython = get_ipython()
# Code to automatically update the EasyTransformer code as its edited without restarting the kernel
ipython.magic("load_ext autoreload")
ipython.magic("autoreload 2")

# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from torchtyping import TensorType as TT
from typing import List, Union, Optional, Tuple
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
from torchtyping import TensorType

import circuitsvis
from circuitsvis import attention

import easy_transformer
import easy_transformer.utils as utils
from easy_transformer.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from easy_transformer import EasyTransformer, EasyTransformerConfig, FactoredMatrix, ActivationCache

import pandas as pd
from tqdm import tqdm

torch.set_grad_enabled(False)



Running as a Jupyter notebook - intended for development only!


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


<torch.autograd.grad_mode.set_grad_enabled at 0x7fa9e051b580>

## Generate bigrams

In [2]:
model = EasyTransformer.from_pretrained(
    "NeelNanda/SoLU_1L512W_C4_Code",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device="cpu"
)

attn_model = EasyTransformer.from_pretrained(
    "NeelNanda/Attn_Only_1L512W_C4_Code",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device="cpu"
    )

tokenizer = model.tokenizer

Loading model: NeelNanda/SoLU_1L512W_C4_Code
Moving model to device:  cpu
Finished loading pretrained model NeelNanda/SoLU_1L512W_C4_Code into EasyTransformer!
Loading model: NeelNanda/Attn_Only_1L512W_C4_Code
Moving model to device:  cpu
Finished loading pretrained model NeelNanda/Attn_Only_1L512W_C4_Code into EasyTransformer!


In [14]:
tokens = []
# 48261
for tokenIndex in tqdm(range(48261)):
    # Get the token
    token_idx = tokenizer.decode(tokenIndex)
    
    # Initialise the token details
    details = {
        "idx": tokenIndex,
        "token": token_idx,
    }
    
    # Get the top k logits & corresponding tokens
    logits = model(token_idx)[0][1]
    probs = F.softmax(logits, dim=-1)
    _, topKTokens = torch.topk(probs, 10, largest=True)
    
    for idx, token_idx in enumerate(topKTokens):
        token_idx = token_idx.item()
        token = tokenizer.decode(token_idx)
        prob = probs[token_idx]
        
        # Add to the details
        details[f"top_{idx}_idx"] = token_idx
        details[f"top_{idx}_token"] = token
        details[f"top_{idx}_prob"] = prob.item()
    
    # Add to tokens list
    tokens.append(details)
    
# Convert to a dataframe
bigrams = pd.DataFrame(tokens)

100%|██████████| 48261/48261 [08:31<00:00, 94.28it/s] 


In [20]:
bigrams.sort_values("top_0_prob", ascending=True).head(30)

Unnamed: 0,idx,token,top_0_idx,top_0_token,top_0_prob,top_1_idx,top_1_token,top_1_prob,top_2_idx,top_2_token,...,top_6_prob,top_7_idx,top_7_token,top_7_prob,top_8_idx,top_8_token,top_8_prob,top_9_idx,top_9_token,top_9_prob
510,510,The,1640,best,0.005995,801,first,0.005678,944,most,...,0.003201,743,new,0.003159,1689,next,0.002942,310,I,0.002678
380,380,The,801,first,0.006765,1640,best,0.005584,944,most,...,0.003247,1689,next,0.003159,374,2,0.003153,642,other,0.003138
434,434,'s,248,a,0.007026,254,the,0.006905,1193,own,...,0.005753,1464,life,0.005432,671,time,0.005009,286,and,0.004526
8821,8821,respective,2127,values,0.007806,16,.,0.007429,4203,countries,...,0.004329,671,time,0.004022,4105,parts,0.003864,3462,products,0.003846
25514,25514,_{(\,65,_,0.008869,15,-,0.007593,14,",",...,0.003889,274,of,0.003874,16,.,0.003828,311,is,0.003747
36331,36331,leurs,3418,questions,0.008921,3284,events,0.006317,3134,services,...,0.005452,8227,experiences,0.00542,6109,stories,0.005154,269,p,0.004835
9383,9383,på,276,in,0.008977,270,f,0.007449,14,",",...,0.006292,282,to,0.005844,1798,den,0.005723,2586,site,0.005525
12732,12732,în,270,f,0.009644,1188,tr,0.009373,248,a,...,0.006734,274,of,0.006496,69,c,0.006015,338,1,0.005642
47707,47707,"""--",455,all,0.009845,2523,time,0.009165,3933,from,...,0.007008,12712,help,0.006571,4,"""",0.006069,12767,debug,0.005425
254,254,the,801,first,0.009971,1640,best,0.009654,1499,world,...,0.005482,1026,way,0.005396,1689,next,0.004245,743,new,0.004167


In [22]:
bigrams["top_0_prob"].describe()

count    48261.000000
mean         0.193992
std          0.163758
min          0.005995
25%          0.091559
50%          0.143002
75%          0.227320
max          0.999555
Name: top_0_prob, dtype: float64

In [21]:
# Save to CSV
bigrams.to_csv("bigrams.csv", index=False)

### Generate prompts to test

In [37]:

prompts = []
answers_correct = []
answers_incorrect = []

i = 0

for idx, row in bigrams.sort_values("top_0_prob", ascending=True).iterrows():
    # Get the token & most likely next tokens
    token = str(row["token"])
    most_likely = str(row["top_0_token"])
    less_likely = str(row["top_1_token"]) # Can change to 1-9 
    
    # Create the prompt
    prompts.append(f"{token}{less_likely} {token}{less_likely} {token}")
    answers_correct.append(less_likely)
    answers_incorrect.append(most_likely)  
    
    i += 1
    if i >= 100:
        break  

In [38]:
df = pd.DataFrame([prompts, answers_correct, answers_incorrect]).T

# Set all column types as strings
df = df.astype(str)
df.columns = ["prompt", "answer_correct", "answer_incorrect"]

# Count any answer_correct that are not strings
not_strings = df["answer_correct"].apply(lambda x: not isinstance(x, str))

df.to_csv("bigram_prompts.csv", index=False)

## Attn vs SoLU

In [41]:
diffs = []
solu_res = []
attn_res = []

for idx, prompt in tqdm(enumerate(prompts)):
    
    answer = answers_correct[idx]
    answer_token = model.to_single_token(answer[0])
    
    solu_logits = model(prompt)
    attn_logits = attn_model(prompt)
    
    solu_last_logits = solu_logits[0][1]
    attn_last_logits = attn_logits[0][1]
    
    solu_probs = F.softmax(solu_last_logits, dim=-1)
    attn_probs = F.softmax(attn_last_logits, dim=-1)
    
    solu_correct_prob = solu_probs[answer_token]
    attn_correct_prob = attn_probs[answer_token]
    
    diff = solu_correct_prob - attn_correct_prob
    
    diffs.append(diff.item())
    solu_res.append(solu_correct_prob.item())
    attn_res.append(attn_correct_prob.item())

    
diffs_series = pd.Series(diffs)
solu_res_series = pd.Series(solu_res)
attn_res_series = pd.Series(attn_res)

# Show summary statistics
diffs_series.describe(), solu_res_series.describe(), attn_res_series.describe()

100it [00:02, 46.41it/s]


(count    100.000000
 mean       0.000583
 std        0.006990
 min       -0.030716
 25%       -0.001032
 50%       -0.000029
 75%        0.003791
 max        0.014648
 dtype: float64,
 count    100.000000
 mean       0.005701
 std        0.006236
 min        0.000003
 25%        0.000266
 50%        0.001626
 75%        0.011957
 max        0.020001
 dtype: float64,
 count    100.000000
 mean       0.005118
 std        0.008472
 min        0.000003
 25%        0.000463
 50%        0.001630
 75%        0.006255
 max        0.042072
 dtype: float64)