## Setup

In [106]:
IN_COLAB = False
print("Running as a Jupyter notebook - intended for development only!")
from IPython import get_ipython

ipython = get_ipython()
# Code to automatically update the EasyTransformer code as its edited without restarting the kernel
ipython.magic("load_ext autoreload")
ipython.magic("autoreload 2")

# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from torchtyping import TensorType as TT
from typing import List, Union, Optional, Tuple
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
from torchtyping import TensorType

import circuitsvis
from circuitsvis import attention

import easy_transformer
import easy_transformer.utils as utils
from easy_transformer.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from easy_transformer import EasyTransformer, EasyTransformerConfig, FactoredMatrix, ActivationCache

import pandas as pd
from tqdm import tqdm

torch.set_grad_enabled(False)



Running as a Jupyter notebook - intended for development only!
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


<torch.autograd.grad_mode.set_grad_enabled at 0x7f50bc592280>

## Generate bigrams

In [None]:
model = EasyTransformer.from_pretrained(
    "NeelNanda/SoLU_1L512W_C4_Code",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device="cpu"
)

tokenizer = model.tokenizer

In [89]:
tokens = []
# 48261
for tokenIndex in tqdm(range(48261)):
    # Get the token
    token_idx = tokenizer.decode(tokenIndex)
    
    # Initialise the token details
    details = {
        "idx": tokenIndex,
        "token": token_idx,
    }
    
    # Get the top k logits & corresponding tokens
    logits = model(token_idx)[0][1] * -1
    log_probs = F.log_softmax(logits, dim=-1)
    _, topKTokens = torch.topk(log_probs, 10, largest=False)
    
    for idx, token_idx in enumerate(topKTokens):
        token_idx = token_idx.item()
        token = tokenizer.decode(token_idx)
        log_prob = log_probs[token_idx]
        
        # Add to the details
        details[f"top_{idx}_idx"] = token_idx
        details[f"top_{idx}_token"] = token
        details[f"top_{idx}_log_prob"] = log_prob.item()
        details[f"top_{idx}_prob"] = torch.exp(log_prob).item()
    
    # Add to tokens list
    tokens.append(details)
    
# Convert to a dataframe
bigrams = pd.DataFrame(tokens)

100%|██████████| 48261/48261 [09:04<00:00, 88.62it/s] 


In [94]:
bigrams.sort_values("top_0_log_prob", ascending=True).head(30)

Unnamed: 0,idx,token,top_0_idx,top_0_token,top_0_log_prob,top_0_prob,top_1_idx,top_1_token,top_1_log_prob,top_1_prob,...,top_7_log_prob,top_7_prob,top_8_idx,top_8_token,top_8_log_prob,top_8_prob,top_9_idx,top_9_token,top_9_log_prob,top_9_prob
46680,46680,doxor,537,ub,-49.918171,2.093215e-22,892,io,-45.722546,1.389795e-20,...,-44.249767,6.061362e-20,386,ant,-44.237835,6.134122e-20,336,ul,-44.195564,6.398974e-20
39831,39831,)[$,65,_,-44.780643,3.5646229999999996e-20,16,.,-44.176205,6.524062e-20,...,-43.102398,1.9092659999999998e-19,21,3,-43.045776,2.0204909999999998e-19,10,(,-43.040642,2.0308919999999999e-19
42706,42706,django,16,.,-43.669113,1.0832909999999999e-19,65,_,-39.871948,4.82873e-18,...,-34.897434,6.986133e-16,19191,jango,-34.142761,1.485892e-15,188,\n,-34.006069,1.703538e-15
39769,39769,googleapis,16,.,-42.788132,2.6142699999999997e-19,65,_,-36.27253,1.766203e-16,...,-34.406872,1.140999e-15,17,/,-34.26812,1.310825e-15,3,!,-33.734135,2.235895e-15
17857,17857,pcbi,16,.,-42.189243,4.758218999999999e-19,533,which,-32.767799,5.876592e-15,...,-30.74773,4.430272e-14,328,on,-30.586317,5.206325e-14,276,in,-30.569227,5.296065e-14
29817,29817,CTYPE,10,(,-41.39959,1.048061e-18,65,_,-38.195564,2.5815300000000002e-17,...,-35.822189,2.7709e-16,16,.,-35.570183,3.565052e-16,4862,[',-35.508781,3.790811e-16
44326,44326,Jagu,275,ar,-40.845989,1.823108e-18,1020,ars,-38.983204,1.1743830000000001e-17,...,-30.836248,4.054968e-14,268,al,-30.702444,4.635515e-14,15,-,-30.653946,4.865869e-14
28707,28707,fasterxml,16,.,-40.784729,1.938283e-18,28,:,-32.894466,5.177436e-15,...,-29.990166,9.450103e-14,15,-,-29.970562,9.637188e-14,348,as,-29.962727,9.712995e-14
27444,27444,intimid,834,ating,-40.640915,2.2380760000000002e-18,366,ate,-38.973755,1.1855320000000001e-17,...,-33.8522,1.986902e-15,796,ative,-32.374767,8.705971e-15,2395,atory,-32.202457,1.034309e-14
13297,13297,$^{-,19,1,-40.242626,3.3331090000000002e-18,20,2,-38.650776,1.6375e-17,...,-33.024708,4.545186e-15,3857,length,-32.658726,6.553837e-15,75,i,-32.17202,1.066275e-14


In [97]:
b = pd.read_csv("bigrams.csv")
b.head()

  b = pd.read_csv("bigrams.csv")


Unnamed: 0,idx,token,top_0_idx,top_0_token,top_0_log_prob,top_0_prob,top_1_idx,top_1_token,top_1_log_prob,top_1_prob,...,top_7_log_prob,top_7_prob,top_8_idx,top_8_token,top_8_log_prob,top_8_prob,top_9_idx,top_9_token,top_9_log_prob,top_9_prob
0,0.0,<|EOS|>,510,The,-28.08642578125,0.0,1516,This,-27.121768951416016,0.0,...,-26.040726,0.0,1253,What,-25.94207,0.0,19,1,-25.915298,5.560693e-12
1,1.0,<|BOS|>,16,.,-24.693307876586918,0.0,282,to,-24.32893943786621,0.0,...,-23.646439,0.0,276,in,-23.592276,0.0,274,of,-23.579889,5.746246e-11
2,2.0,<|PAD|>,16,.,-23.047046661376957,0.0,282,to,-22.706945419311523,0.0,...,-21.823261,0.0,324,for,-21.589272,0.0,248,a,-21.55468,4.354331e-10
3,3.0,!,188,\n,-28.494014739990234,0.0,0,<|EOS|>,-27.0388240814209,0.0,...,-25.018461,0.0,731,It,-24.92601,0.0,496,In,-24.387125,2.56334e-11
4,4.0,"""",188,\n,-24.71556854248047,0.0,477,\n,-23.75620651245117,0.0,...,-22.65057,0.0,311,is,-22.650215,0.0,1465,\n\n,-22.582977,1.557175e-10


In [111]:
# List of prompts (as a list of token indices)
prompts: List[List[int]] = []
# List of answers, in the format (correct, incorrect)
answers: List[Tuple[int, int]] = []

for idx, row in bigrams.iterrows():
    tokenIdx = int(row["idx"])
    most_likely = int(row["top_0_idx"])
    second_likely = int(row["top_1_idx"])
    prompt = [tokenIdx, second_likely, 510, tokenIdx]
    answer = (second_likely, most_likely)
    prompts.append(prompt)
    answers.append(answer)
    
prompts = torch.tensor(prompts)
answers = torch.tensor(answers)

prompts.shape, answers.shape

(torch.Size([48261, 4]), torch.Size([48261, 2]))

In [110]:
prompts.shape, answers.shape

(torch.Size([48261, 4]), torch.Size([48261, 2]))