# Comparison of Transformer Models with Code dataset

Comparison of models using a dataset of Python code. The aim is to find interesting prompts where an MLP layer is required to accurately predict the next token.


## Setup

### Imports

In [4]:
from transformers import AutoTokenizer
from IPython.display import display, clear_output
from ipywidgets import widgets
from typing import List
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from tqdm import tqdm
from circuitsvis.tokens import colored_tokens
from codegen import infer
from easy_transformer import EasyTransformer
import torch
import os
from pathlib import Path
from torchtyping import TensorType

### Get code prompts (training dataset)

In [5]:
# Fetch the pre-tokenized code prompts (train split)
prompts_dataset = load_dataset("NeelNanda/code-tokenized", split="train")

# Gather every example's token ids into a nested list [prompts x tokens]
prompts_tokens: List[List[int]] = [example["tokens"] for example in tqdm(prompts_dataset)]

# Stack the nested list into one NumPy array and display its dimensions
prompts_tokens = np.array(prompts_tokens)
prompts_tokens.shape

Using custom data configuration NeelNanda--code-tokenized-d313277bd840bb66
Found cached dataset parquet (/home/user/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--code-tokenized-d313277bd840bb66/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
100%|██████████| 297257/297257 [02:40<00:00, 1850.56it/s]


(297257, 1024)

### Model names

In [6]:
# Entries of the log-probs directory are used as model names; list them alphabetically
model_names = sorted(os.listdir("data/full_pred_log_probs/code/"))
model_names

FileNotFoundError: [Errno 2] No such file or directory: 'data/full_pred_log_probs/code/'

### Disable automatic differentiation

This saves GPU resources, as we're just doing inference.

In [None]:
# Switch gradient tracking off globally
torch.set_grad_enabled(False)

## Model accuracy comparison

In [None]:

# Output widget that captures the histogram rendered by this cell
out = widgets.Output()

def _load_correct_log_probs(model_name: str, version: str = "v1") -> torch.Tensor:
    """Load the log probs one model assigned to the correct tokens.

    Args:
        model_name: Name of the model's sub-directory under
            ``data/full_pred_log_probs/code``.
        version: Sub-directory (below the model directory) holding the file.

    Returns:
        CPU tensor of log probs, laid out as [prompt x tokens].
    """
    log_probs_path = (
        Path("data") / "full_pred_log_probs" / "code" / model_name / version / "pred_log_probs.pth"
    )
    # map_location="cpu" lets tensors that were saved from a GPU be loaded on a
    # machine without one (a plain load followed by .to("cpu") would fail there).
    negative_log_probs = torch.load(log_probs_path, map_location="cpu")
    # Multiply by -1, as the stored data is negative log probs
    return negative_log_probs * -1


def get_model_correct_log_probs_diff(model_1_name: str, model_2_name: str, log_prob: bool = True) -> np.ndarray:
    """Compare how well two models predicted the correct tokens.

    Args:
        model_1_name: Model used as the baseline (subtracted).
        model_2_name: Model whose values the baseline is subtracted from.
        log_prob: If True, return the difference in log probs; otherwise
            exponentiate first and return the difference in probabilities.

    Returns:
        Array of ``model_2 - model_1`` values, laid out as [prompt x tokens].
        Positive entries mean model 2 gave the correct token a higher
        (log) probability than model 1.
    """
    model_1_correct_log_probs = _load_correct_log_probs(model_1_name)
    model_2_correct_log_probs = _load_correct_log_probs(model_2_name)

    # Get the differences
    if log_prob:
        diff = model_2_correct_log_probs - model_1_correct_log_probs
    else:
        diff = model_2_correct_log_probs.exp() - model_1_correct_log_probs.exp()

    return diff.numpy()

def plot_hist_diff_by_prompt(model_1_name: str, model_2_name: str, log_prob: bool = True) -> None:
    """Draw a histogram of the per-prompt mean difference between two models.

    Args:
        model_1_name: Baseline model.
        model_2_name: Model compared against the baseline.
        log_prob: Compare log probs when True, probabilities when False.
    """
    token_diffs = get_model_correct_log_probs_diff(model_1_name, model_2_name, log_prob)

    # Average over the token axis so each prompt contributes one value
    mean_diff_per_prompt = token_diffs.mean(axis=-1)

    if log_prob:
        metric_label = "log probs"
    else:
        metric_label = "probs"

    plt.hist(mean_diff_per_prompt, bins=100)
    plt.title(f"Distribution of prompts correct {metric_label}, between {model_2_name} & {model_1_name}")

    # Render the figure inside the cell's output widget
    with out:
        plt.show()
    
# Create dropdowns
metric_selector = widgets.Dropdown(
    options=[("Log probs", True), ("Probs", False)],
    description="Metric",
    value=False
)
model_1_selector = widgets.Dropdown(
    options=model_names,
    description='Model 1',
    value="attn-only-1l"
)
model_2_selector = widgets.Dropdown(
    options=model_names,
    description='Model 2',
    value="solu-1l"
)

# Private aliases for this cell's own widgets. A later cell rebinds the global
# names `model_1_selector`, `model_2_selector` and `metric_selector` to its own
# dropdowns; a handler that looked those globals up at call time would then read
# the other cell's selections. The aliases keep the handler tied to the widgets
# displayed here.
_hist_model_1_selector = model_1_selector
_hist_model_2_selector = model_2_selector
_hist_metric_selector = metric_selector


# Handle changes
def on_change(change):
    """Redraw the histogram whenever one of this cell's dropdown values changes."""
    if change["type"] == "change" and change["name"] == "value":
        with out:
            clear_output()
            plot_hist_diff_by_prompt(
                _hist_model_1_selector.value,
                _hist_model_2_selector.value,
                _hist_metric_selector.value,
            )


# names="value" limits notifications to changes of the selected value
model_1_selector.observe(on_change, names="value")
model_2_selector.observe(on_change, names="value")
metric_selector.observe(on_change, names="value")

display(model_1_selector, model_2_selector, metric_selector, out)

# default
with out:
    plot_hist_diff_by_prompt(model_1_selector.value, model_2_selector.value, metric_selector.value)

Dropdown(description='Model 1', options=('attn-only-1l', 'attn-only-2l', 'attn-only-3l', 'attn-only-4l', 'gelu…

Dropdown(description='Model 2', index=8, options=('attn-only-1l', 'attn-only-2l', 'attn-only-3l', 'attn-only-4…

Dropdown(description='Metric', index=1, options=(('Log probs', True), ('Probs', False)), value=False)

Output()

## Model prompts comparison

In [None]:

# Output widgets: `model_metric_out` holds the prompt-selection dropdown,
# `prompt_vis_out` holds the coloured-token visualisation of the chosen prompt
model_metric_out = widgets.Output()
prompt_vis_out = widgets.Output()

# Tokenizers already loaded by `_get_prompt_tokenizer`, keyed by name
_PROMPT_TOKENIZER_CACHE = {}


def _get_prompt_tokenizer(name: str = "NeelNanda/gpt-neox-tokenizer-digits"):
    """Return the tokenizer called `name`, loading it only on first request."""
    if name not in _PROMPT_TOKENIZER_CACHE:
        _PROMPT_TOKENIZER_CACHE[name] = AutoTokenizer.from_pretrained(name)
    return _PROMPT_TOKENIZER_CACHE[name]


def show_prob_diff_colored_tokens(prompt_id: int, diff: np.ndarray):
    """Build a coloured-token view of one prompt's per-token differences.

    Args:
        prompt_id: Row index of the prompt, used for both `diff` and the
            global `prompts_tokens` array.
        diff: Array of per-token differences, indexed as [prompt, token].

    Returns:
        The object produced by `colored_tokens`, ready to be passed to
        `display`.
    """
    prompt_token_diff = diff[prompt_id, :]

    # Reuse the tokenizer across calls; this function runs on every dropdown change
    tokenizer = _get_prompt_tokenizer()

    # The last token is left out of the visualisation — presumably because `diff`
    # has one value fewer than the prompt has tokens (TODO confirm). Slice before
    # decoding so the dropped token is never decoded.
    token_strings = [tokenizer.decode(t) for t in prompts_tokens[prompt_id][:-1]]

    return colored_tokens(
        values=prompt_token_diff.tolist(),
        tokens=token_strings,
        min_value=-1,
        max_value=1,
    )

def show_token_vis(model_1_name: str, model_2_name: str, log_prob: bool = True) -> None:
    """Show a prompt picker plus a coloured-token view for a pair of models.

    Prompts are ranked by how many of their tokens exceed a threshold in
    ``model_2 - model_1``; the 500 highest-ranked prompts are offered in a
    dropdown, and the top-ranked one is rendered immediately.

    Args:
        model_1_name: Baseline model.
        model_2_name: Model compared against the baseline.
        log_prob: Compare log probs when True, probabilities when False.
    """
    # Get differences
    diff = get_model_correct_log_probs_diff(model_1_name, model_2_name, log_prob)

    # Get number of tokens in each prompt that are significantly more correct
    threshold = 1 if log_prob else 0.3
    count_tokens_above_threshold = np.sum(diff > threshold, axis=-1)
    ranked_prompts = pd.DataFrame({"count_significant": count_tokens_above_threshold}).sort_values(
        by="count_significant", ascending=False
    )
    prompt_ids_ranked = ranked_prompts.index.values

    widgets_selector = widgets.Dropdown(
        options=[(f"{ind}: Prompt {prompt_id}", prompt_id) for ind, prompt_id in enumerate(prompt_ids_ranked[0:500])],
        description='Prompt ID',
    )

    with model_metric_out:
        display(widgets_selector)

    def render_prompt(prompt_id) -> None:
        """Replace the contents of `prompt_vis_out` with the given prompt."""
        with prompt_vis_out:
            # Clear first: without this, each model/metric change would append
            # another visualisation below the previous one.
            clear_output()
            display(show_prob_diff_colored_tokens(prompt_id, diff))

    def on_change(change):
        if change['type'] == 'change' and change['name'] == 'value':
            render_prompt(change['new'])

    # names="value" limits notifications to changes of the selected prompt
    widgets_selector.observe(on_change, names="value")

    # Default
    render_prompt(prompt_ids_ranked[0])
    
    
    
# Create model/metric dropdowns
metric_selector = widgets.Dropdown(
    options=[("Log probs", True), ("Probs", False)],
    description="Metric",
    value=False
)
model_1_selector = widgets.Dropdown(
    options=model_names,
    description='Model 1',
    value="attn-only-1l"
)
model_2_selector = widgets.Dropdown(
    options=model_names,
    description='Model 2',
    value="solu-1l"
)

# Private aliases for this cell's own widgets. An earlier cell uses the same
# global names for its dropdowns; if that cell is re-run, the globals point at
# its widgets again. The aliases keep this cell's handler tied to the widgets
# displayed here.
_token_vis_model_1_selector = model_1_selector
_token_vis_model_2_selector = model_2_selector
_token_vis_metric_selector = metric_selector


# Handle changes
def on_change(change):
    """Rebuild the prompt picker and token view when a dropdown value changes."""
    if change["type"] == "change" and change["name"] == "value":
        with model_metric_out:
            clear_output()
            show_token_vis(
                _token_vis_model_1_selector.value,
                _token_vis_model_2_selector.value,
                _token_vis_metric_selector.value,
            )


# names="value" limits notifications to changes of the selected value
model_1_selector.observe(on_change, names="value")
model_2_selector.observe(on_change, names="value")
metric_selector.observe(on_change, names="value")

display(model_1_selector, model_2_selector, metric_selector, model_metric_out, prompt_vis_out)

# default
with model_metric_out:
    show_token_vis(model_1_selector.value, model_2_selector.value, metric_selector.value)

Dropdown(description='Model 1', options=('attn-only-1l', 'attn-only-2l', 'attn-only-3l', 'attn-only-4l', 'gelu…

Dropdown(description='Model 2', index=8, options=('attn-only-1l', 'attn-only-2l', 'attn-only-3l', 'attn-only-4…

Dropdown(description='Metric', index=1, options=(('Log probs', True), ('Probs', False)), value=False)

Output()

Output()

## Manual analysis

### Top prompts

In [None]:
# Per-token difference in probability of the correct token (solu-1l minus attn-only-1l)
diff = get_model_correct_log_probs_diff("attn-only-1l", "solu-1l", False)

# Rank prompts by how many of their tokens exceed the threshold
threshold = 0.3
count_tokens_above_threshold = (diff > threshold).sum(axis=-1)
ranked_prompts = pd.DataFrame({"count_significant": count_tokens_above_threshold})
ranked_prompts = ranked_prompts.sort_values(by="count_significant", ascending=False)
prompt_ids_ranked = ranked_prompts.index.values

# Decode the first 1000 ranked prompts and display the first 200 of them
# NOTE(review): despite its name, `top_100_prompts` holds every decoded prompt
# from `top_prompts` (up to 1000), not 100.
top_prompts = prompt_ids_ranked[:1000]
tokenizer = AutoTokenizer.from_pretrained("NeelNanda/gpt-neox-tokenizer-digits")
top_100_prompts = [tokenizer.decode(token_ids) for token_ids in prompts_tokens[top_prompts]]
pd.set_option('display.max_rows', 500)
pd.DataFrame({"prompt": top_100_prompts}).head(200)


Unnamed: 0,prompt
0,<|BOS|>0000000000' #...\n '0000000000000000...
1,<|BOS|> # # # # # 0 #\n# 0 # 0 0 0 0 0 # 0 0 0...
2,<|BOS|>EXPR_SUB_ASSIGN = enum.auto()\n\tEXPR_D...
3,"<|BOS|>0.118),\n(0.431,0.431,0.118),\n(0.431,0..."
4,"<|BOS|>,0.000,0.000,0.000,0.000,0.000,0.000],\..."
5,"<|BOS|>0.000,0.000,0.000,0.000,0.000,0.000,0.0..."
6,"<|BOS|>.000,0.000,0.000,0.000,0.000,0.000,0.00..."
7,"<|BOS|>0,0.000,0.000,0.000,0.000,0.000,0.000],..."
8,"<|BOS|>0,0.000,0.000,0.000,0.000,0.000,0.000],..."
9,"<|BOS|>,0.000,0.000,0.000,0.000,0.000],\n[75,6..."


### Digging Into MLP vs Attn (1 layer)

#### Notes on things that the MLP layer appears to help with

- `VAR_NAME = ` -> `enum`
- Comma at end of list item (2)
- __Single quote comma at end of list item (4) `',` or `'`__
- `),` and `'),` at end of each list item (2)
- u at start of unicode string (in list) e.g. `u'\U0001d677': ... `
- __Double space after comma in list (2) (also 5 space after close brackets `)`) (7)__
- Long list of `Text = Text.replace(....` (second `Text`)
- New line repeated sequence of var names (`SC_...`)
- `()` in repeated `f()` 
- __Weird random repeated sequences e.g. end of `TSS_TSPATTRIB` and
  `IVISCOPE_ATTR` and ` RPL_` (3) and `IV` and ` (Token` at beginning of each
  new line (2)__
- Permutations of copied tokens (e.g. 165)

### MLP vs Attn Examples

##### Setup

In [8]:
# Tokenizer used to encode and decode the example prompts below
tokenizer = AutoTokenizer.from_pretrained("NeelNanda/gpt-neox-tokenizer-digits")
# The two models being compared — going by the checkpoint names, presumably a
# 1-layer attention-only model and a 1-layer SoLU model (TODO confirm)
attn = EasyTransformer.from_pretrained("NeelNanda/Attn_Only_1L512W_C4_Code")
solu = EasyTransformer.from_pretrained("NeelNanda/SoLU_1L512W_C4_Code")
# Discard anything the loading calls above wrote to the cell output
clear_output()

In [9]:
def show_prompt_comparison(prompt: str):
    """Display a prompt's tokens coloured by the solu-minus-attn probability gap.

    The prompt is tokenized, run through both global models (`attn` and
    `solu`) via `infer.run_batch`, and the exponentiated outputs are
    subtracted (solu minus attn) to colour each token.

    Args:
        prompt: Source text to tokenize and compare.
    """
    # Tokenize once; the ids feed both the model batch and the display strings
    token_ids = tokenizer.encode(prompt)
    # NOTE(review): the device is hard-coded, so this needs a CUDA GPU
    prompt_batch = torch.tensor(token_ids, device="cuda").unsqueeze(0)
    token_strings = [tokenizer.decode(token_id) for token_id in token_ids]

    # Presumably per-token log probs of the correct next token — TODO confirm
    # against `infer.run_batch`
    attn_log_probs = infer.run_batch(attn, prompt_batch).squeeze(0)
    solu_log_probs = infer.run_batch(solu, prompt_batch).squeeze(0)
    print(attn_log_probs.shape)

    attn_probs = np.exp(attn_log_probs.detach().cpu().numpy())
    solu_probs = np.exp(solu_log_probs.detach().cpu().numpy())
    diff = solu_probs - attn_probs

    # NOTE(review): every token string is passed here, with no trimming to
    # match the number of values — confirm the two lengths agree
    display(
        colored_tokens(
            values=diff.tolist(),
            tokens=token_strings,
            min_value=-1,
            max_value=1,
        )
    )

#### Lists (new lines)

Note that spaces following a newline are, confusingly, tokenized together with
the newline character.

Based on 164 (129820)

In [10]:
# Short list: four quoted strings, one indented item per line
prompt = """my_arr = [
    'DDR4 SDRAM',
    'Medal of Honor: Frontline',
    'Sisters of War',
    'Batalla de Monte Tumbledown',
]"""
show_prompt_comparison(prompt)

torch.Size([42])


In [11]:
# List with no newlines: all items on a single line
# NOTE(review): the last item (Gypsy Heart Tour') has no opening quote, and the
# separators mix one and two spaces — confirm whether this is deliberate
prompt = """my_arr = ['DDR4 SDRAM',  'Medal of Honor: Frontline',  'Sisters of War', 'Batalla de Monte Tumbledown', Gypsy Heart Tour',]"""
show_prompt_comparison(prompt)

torch.Size([43])


In [22]:
# Longer list of strings, one item per line. A few items differ from the rest
# (extra trailing text, a deeper indent, a truncated repeat of an earlier
# item) — presumably deliberate perturbations, TODO confirm
prompt = """my_arr = [
    'DDR4 SDRAM asdf as f dsdsd',
    'Medal of Honor: Frontline',
    'Sisters of War: The Second Coming',
    'Batalla de Monte Tumbledown',
    'Gypsy Heart Tour',
    'Monster in My Pocket',
    'El problema del costo social',
    'Dōjutsu',
    'Elizabeth Eichhorn',
    'Plataforma HD',
    'Jeremy Scahill',
            'Caxuxi',
    'Marbella Corella',
    'Boris Kodjoe',
    'Sisters of',
    'Carol Cleveland',
    'Joseph Morgan',
    'Aidan Alexander',
    'Sentispac',
]"""

show_prompt_comparison(prompt)

torch.Size([161])


In [None]:
# List of numbers: the same five decimals repeated three times, one per line,
# with one entry in the second repetition given extra trailing digits
prompt = """my_arr = [
    0.128123,
    0.457345,
    0.231342,
    0.234234,
    0.123671,
    0.128123,
    0.457345,
    0.231342,
    0.234234444444444,
    0.123671,
    0.128123,
    0.457345,
    0.231342,
    0.234234,
    0.123671,
]"""

show_prompt_comparison(prompt)

torch.Size([165])


In [None]:
# List without indent: one quoted string per line, starting at column 0
prompt = """my_arr = [
'DDR4 SDRAM',
'Medal of Honor: Frontline',
'Sisters of War',
'Batalla de Monte Tumbledown',
'Gypsy Heart Tour',
'Monster in My Pocket',
'El problema del costo social',
'Dōjutsu',
'Elizabeth Eichhorn',
'Plataforma HD',
'Jeremy Scahill',
'Caxuxi',
'Marbella Corella',
'Boris Kodjoe',
'Carol Cleveland',
'Joseph Morgan',
'Aidan Alexander',
'Sentispac',
]"""

show_prompt_comparison(prompt)

torch.Size([144])


In [None]:
# Mashed text middle: the list is interrupted part-way through by several lines
# of unstructured text, then continues with well-formed items
prompt = """my_arr = [
    'DDR4 SDRAM',
    'Medal of Honor: Frontline',
    'Sisters of War',
    'Batalla de Monte Tumbledown',
    'Gypsy Heart Tour',
    'Monster in My Pocket',
  jsahdfkahskdfjhaksjfhd jkash dfjb vas dfja sdhfas fioah
  
  ashdfhas d
  asiodfjo]apsdf[ as
  df]
    'Caxuxi',
    'Marbella Corella',
    'Boris Kodjoe',
    'Carol Cleveland',
    'Joseph Morgan',
    'Aidan Alexander',
    'Sentispac',
]"""

show_prompt_comparison(prompt)

torch.Size([149])


#### . properties (182)

In [None]:
# Base: repeated `hexcodes[...] = (...)` assignments, one per line
# NOTE(review): the first, third, fourth and fifth lines start with `.hexcodes`
# (no `self`) while the others use `self.hexcodes` — confirm this mix is intended
prompt = """        .hexcodes[0xA6] = ("ldx", "zeropage")
        self.hexcodes[0xB6] = ("ldx", "zeropagey")
        .hexcodes[0xC6] = ("dec", "zeropage")
        .hexcodes[0xD6] = ("dec", "zeropagex")
        .hexcodes[0xE6] = ("inc", "zeropage")
        self.hexcodes[0xF6] = ("inc", "zeropagex")
        self.hexcodes[0x07] = ("", "")
        self.hexcodes[0x17] = ("", "")
        self.hexcodes[0x27] = ("", "")
        self.hexcodes[0x37] = ("", "")
        self.hexcodes[0x47] = ("", "")
        self.hexcodes[0x57] = ("", "")
        self.hexcodes[0x67] = ("", "")
        self.hexcodes[0x07] = ("", "")
        self.hexcodes[0x17] = ("", "")"""
 
show_prompt_comparison(prompt)

torch.Size([268])


In [None]:
# Mashed middle: the run of `self.hexcodes[...]` assignments is interrupted by
# two lines of unstructured text, then resumes
prompt = """        self.hexcodes[0xA6] = ("ldx", "zeropage")
        self.hexcodes[0xB6] = ("ldx", "zeropagey")
        self.hexcodes[0xC6] = ("dec", "zeropage")
        self.hexcodes[0xD6] = ("dec", "zeropagex")
        self.hexcodes[0xE6] = ("inc", "zeropage")
        self.hexcodes[0xF6] = ("inc", "zeropagex")

        sdafasdf as dfas asfd as df
        asdfsa dfas df
        
        self.hexcodes[0x47] = ("", "")
        self.hexcodes[0x57] = ("", "")
        self.hexcodes[0x67] = ("", "")
        self.hexcodes[0x07] = ("", "")
        self.hexcodes[0x17] = ("", "")"""

show_prompt_comparison(prompt)

torch.Size([225])


In [None]:
# Different attribute name: same assignments, but on `self.codes` instead of
# `self.hexcodes`
prompt = """        self.codes[0xA6] = ("ldx", "zeropage")
        self.codes[0xB6] = ("ldx", "zeropagey")
        self.codes[0xC6] = ("dec", "zeropage")
        self.codes[0xD6] = ("dec", "zeropagex")
        self.codes[0xE6] = ("inc", "zeropage")
        self.codes[0xF6] = ("inc", "zeropagex")
        self.codes[0x47] = ("", "")
        self.codes[0x57] = ("", "")
        self.codes[0x67] = ("", "")
        self.codes[0x07] = ("", "")
        self.codes[0x17] = ("", "")"""

show_prompt_comparison(prompt)

torch.Size([182])


In [None]:
# Without self: every line starts with `.hexcodes`, with no `self` prefix
prompt = """        .hexcodes[0xA6] = ("ldx", "zeropage")
        .hexcodes[0xB6] = ("ldx", "zeropagey")
        .hexcodes[0xC6] = ("dec", "zeropage")
        .hexcodes[0xD6] = ("dec", "zeropagex")
        .hexcodes[0xE6] = ("inc", "zeropage")
        .hexcodes[0xF6] = ("inc", "zeropagex")
        .hexcodes[0x07] = ("", "")
        .hexcodes[0x17] = ("", "")
        .hexcodes[0x27] = ("", "")"""
 
show_prompt_comparison(prompt)

torch.Size([163])


- 