## Setup

In [None]:
IN_COLAB = False
print("Running as a Jupyter notebook - intended for development only!")
from IPython import get_ipython

ipython = get_ipython()
# Code to automatically update the EasyTransformer code as its edited without restarting the kernel
ipython.magic("load_ext autoreload")
ipython.magic("autoreload 2")

# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from torchtyping import TensorType as TT
from typing import List, Union, Optional, Tuple
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
from torchtyping import TensorType

import circuitsvis
from circuitsvis import attention

import easy_transformer
import easy_transformer.utils as utils
from easy_transformer.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from easy_transformer import EasyTransformer, EasyTransformerConfig, FactoredMatrix, ActivationCache

import pandas as pd
from tqdm import tqdm

torch.set_grad_enabled(False)



## Generate bigrams

In [None]:
model = EasyTransformer.from_pretrained(
    "NeelNanda/SoLU_1L512W_C4_Code",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device="cpu"
)

attn_model = EasyTransformer.from_pretrained(
    "NeelNanda/Attn_Only_1L512W_C4_Code",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device="cpu"
    )

tokenizer = model.tokenizer

In [None]:
tokens = []
# 48261
for tokenIndex in tqdm(range(48261)):
    # Get the token
    token_idx = tokenizer.decode(tokenIndex)
    
    # Initialise the token details
    details = {
        "idx": tokenIndex,
        "token": token_idx,
    }
    
    # Get the top k logits & corresponding tokens
    logits = model(token_idx)[0][1] * -1
    log_probs = F.log_softmax(logits, dim=-1)
    _, topKTokens = torch.topk(log_probs, 10, largest=False)
    
    for idx, token_idx in enumerate(topKTokens):
        token_idx = token_idx.item()
        token = tokenizer.decode(token_idx)
        log_prob = log_probs[token_idx]
        
        # Add to the details
        details[f"top_{idx}_idx"] = token_idx
        details[f"top_{idx}_token"] = token
        details[f"top_{idx}_log_prob"] = log_prob.item()
        details[f"top_{idx}_prob"] = torch.exp(log_prob).item()
    
    # Add to tokens list
    tokens.append(details)
    
# Convert to a dataframe
bigrams = pd.DataFrame(tokens)

In [None]:
bigrams.sort_values("top_0_log_prob", ascending=True).head(30)

In [None]:
# List of prompts (as a list of token indices)
prompts: List[List[int]] = []
# List of answers, in the format (correct, incorrect)
answers: List[Tuple[int, int]] = []

for idx, row in bigrams.iterrows():
    tokenIdx = int(row["idx"])
    most_likely = int(row["top_0_idx"])
    second_likely = int(row["top_1_idx"])
    prompt = [tokenIdx, second_likely, 510, tokenIdx]
    answer = (second_likely, most_likely)
    prompts.append(prompt)
    answers.append(answer)
    
prompts = torch.tensor(prompts)
answers = torch.tensor(answers)

prompts.shape, answers.shape

### With words

In [172]:

prompts = []
answers_correct = []
answers_incorrect = []

for idx, row in bigrams.iterrows():
    # Get the token & most likely next tokens
    token = str(row["token"])
    most_likely = str(row["top_0_token"])
    less_likely = str(row["top_1_token"]) # Can change to 1-9 
    
    # Skip if any have spaces
    if " " in token or " " in most_likely or " " in less_likely:
        continue
    
    # Skip if token string contains non-letters
    if not token.isalpha() or not most_likely.isalpha() or not less_likely.isalpha():
        continue
        
    # Create the prompt
    prompts.append(f".{token}{less_likely} .{token}{less_likely} .{token}")
    answers_correct.append(less_likely)
    answers_incorrect.append(most_likely)
    
    

In [173]:
df = pd.DataFrame([prompts, answers_correct, answers_incorrect]).T

# Set all column types as strings
df = df.astype(str)
df.columns = ["prompt", "answer_correct", "answer_incorrect"]

# Count any answer_correct that are not strings
not_strings = df["answer_correct"].apply(lambda x: not isinstance(x, str))

df.to_csv("bigram_prompts.csv", index=False)

## Attn vs SoLU

In [174]:
diffs = []

for idx, prompt in tqdm(enumerate(prompts)):
    answer = answers_correct[idx]
    answer_token = model.to_single_token(answer[0])
    
    solu_logits = model(prompt)
    attn_logits = attn_model(prompt)
    
    solu_last_logits = solu_logits[0][1]
    attn_last_logits = attn_logits[0][1]
    
    solu_probs = F.softmax(solu_last_logits, dim=-1)
    attn_probs = F.softmax(attn_last_logits, dim=-1)
    
    solu_correct_prob = solu_probs[answer_token]
    attn_correct_prob = attn_probs[answer_token]
    
    diff = solu_correct_prob - attn_correct_prob
    
    diffs.append(diff.item())

    
diffs_series = pd.Series(diffs)

# Show summary statistics
diffs_series.describe()

2853it [00:53, 53.37it/s]


count    2853.000000
mean        0.000044
std         0.000184
min        -0.000205
25%        -0.000023
50%         0.000003
75%         0.000052
max         0.000879
dtype: float64