In [1]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

# Load Model

In [2]:
model = HookedTransformer.from_pretrained("gelu-1l")
model("Hello world!")


Loaded pretrained model gelu-1l into HookedTransformer


tensor([[[ 6.0339, -6.4613, -6.4667,  ..., -6.4511, -6.4693, -6.4487],
         [ 2.2151, -8.4667, -8.4197,  ..., -8.4754, -8.4943, -8.4214],
         [ 4.1072, -7.1321, -7.0821,  ..., -7.1124, -7.1540, -7.1635],
         [14.2510, -7.0324, -6.9920,  ..., -6.9795, -7.0364, -6.9483]]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [3]:
# look at the configuration
model.cfg

HookedTransformerConfig:
{'act_fn': 'gelu',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 2048,
 'd_model': 512,
 'd_vocab': 48262,
 'd_vocab_out': 48262,
 'device': 'cuda',
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.035355339059327376,
 'model_name': 'GELU_1L512W_C4_Code',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 8,
 'n_layers': 1,
 'n_params': 3145728,
 'normalization_type': 'LNPre',
 'original_architecture': 'neel',
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'rotary_dim': None,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': None,
 'tokenizer_name': 'NeelNanda/gpt-neox-tokenizer-digits',
 'use_attn_result': False,
 'use_attn_scale': True,
 'use_hook_tokens': False,
 'use_local_attn': False,
 'use_split_qk

In [4]:
# the avaiable modules
dir(model)

['OV',
 'QK',
 'T_destination',
 'W_E',
 'W_E_pos',
 'W_K',
 'W_O',
 'W_Q',
 'W_U',
 'W_V',
 'W_in',
 'W_out',
 'W_pos',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_backward_pre_hooks',
 '_buffers',
 '_call_impl',
 '_forward_hooks',
 '_forward_hooks_with_kwargs',
 '_forward_pre_hooks',
 '_forward_pre_hooks_with_kwargs',
 '_get_backward_hooks',
 '_get_backward_pre_hooks',
 '_get_name',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_post_hooks',
 '_load_state_dict_pre_hooks',
 '_maybe_warn_non_full_backward_hook',
 '_modules',
 '_named_members',

In [5]:
d_head = model.cfg.d_head
d_mlp = model.cfg.d_mlp
d_model = model.cfg.d_model
d_vocab = model.cfg.d_vocab
n_ctx = model.cfg.n_ctx
n_heads = model.cfg.n_heads
n_layers = model.cfg.n_layers

# make a nice printout of the above values
print(f"d_head: {d_head}\nd_mlp: {d_mlp}\nd_model: {d_model}\nd_vocab: {d_vocab}\nn_ctx: {n_ctx}\nn_heads: {n_heads}\nn_layers: {n_layers}")


d_head: 64
d_mlp: 2048
d_model: 512
d_vocab: 48262
n_ctx: 1024
n_heads: 8
n_layers: 1


# Tokenizer and paren tokens

In [17]:
tokenizer = model.tokenizer
all_tokens = list(tokenizer.get_vocab().keys())

tokens_parens = [t for t in all_tokens if "(" in t or ")" in t]
tokens_open_parens = [t for t in all_tokens if "(" in t]
tokens_close_parens = [t for t in all_tokens if ")" in t]
tokens_both_parens = [t for t in all_tokens if "(" in t and ")" in t]

# make a nice print statement for how many tokens of each type we have
# include percentages
print(f"Total tokens: {len(all_tokens)}")
print(f"Tokens with parentheses: {len(tokens_parens)} ({len(tokens_parens)/len(all_tokens)*100:.2f}%)")
print(f"Tokens with open parentheses: {len(tokens_open_parens)} ({len(tokens_open_parens)/len(all_tokens)*100:.2f}%)")
print(f"Tokens with close parentheses: {len(tokens_close_parens)} ({len(tokens_close_parens)/len(all_tokens)*100:.2f}%)")
print(f"Tokens with both parentheses: {len(tokens_both_parens)} ({len(tokens_both_parens)/len(all_tokens)*100:.2f}%)")

# now let's print some example tokens from tokens_parens
# format nicely with tabs and newlines, 8 examples per line
print(f"\nSome example tokens with parentheses:")
print("\n".join(["\t".join(tokens_parens[i:i+10]) for i in range(0, 100, 10)]))


Total tokens: 48262
Tokens with parentheses: 559 (1.16%)
Tokens with open parentheses: 266 (0.55%)
Tokens with close parentheses: 354 (0.73%)
Tokens with both parentheses: 61 (0.13%)

Some example tokens with parentheses:
");	Ġ:)	{})	)}}	([]	()`	}({\	})$$	'(	))?
*(	)/(	)&	)_\	')->	~),	Ġ}).	)}_	)))	().
)-\	^).	({{\	})}	))).	'));	(+)	Ġ'')	"})	(@"
__(	Ġ((	)$.	()</	?)	Ġ(),	Ġ(,	"(	(%	())
**](#	}^{(	Ġ(-	)\|_{	Ġ()]{}	++)	("#	()">	)^	?).
!)	=$(	)*-	Ġ(_	("\	^{(	})$,	^)	Ġ\(	)\[
Ġ(Â±	)}\	)|^	("%	)}$$	)}^	Ġ$$(	))**(	^*(	(?
),(	('	('',	*),	)$;	))\	}_{(	Ġ(__	[\*](#	^(
():	<>();	Ġ(Â§	("/	']);	\](	(".	\])]{}	){#	}({
^{(\	)}{\	(-	)},	Ġ*(	)+\	Ġ(	Ġ});	)[@	}))


# Load and Clean Dataset

In [162]:
# Import the load_dataset function from the Hugging Face's datasets library
from datasets import load_dataset
import tqdm

# Load the "NeelNanda/code-10k" dataset and take the training split
code_data = load_dataset("NeelNanda/code-10k", split="train")

# print the first example
print(code_data[0]['text'])




# Copyright (c) 2019 J. Alvarez-Jarreta and C.J. Brasher
#
# This file is part of the LipidFinder software tool and governed by the
# 'MIT License'. Please see the LICENSE file that should have been
# included as part of this software.
"""Graphical User Interface (GUI) to manage the parameters' collection.
"""

from collections import OrderedDict
import os

from IPython.display import display
from ipywidgets import widgets, Layout
import pandas

from LipidFinder.Configuration import LFParameters
from LipidFinder._utils import normalise_path


class _TaggedToggleButton(widgets.ToggleButton):
    """Add "tag" attribute to widgets.ToggleButton class."""

    def __init__(self, tag, **kwargs):
        widgets.ToggleButton.__init__(self, **kwargs)
        self.tag = tag


class _TaggedCheckbox(widgets.Checkbox):
    """Add "tag" attribute to widgets.Checkbox class."""

    def __init__(self, tag, **kwargs):
        widgets.Checkbox.__init__(self, **kwargs)
        self.tag = tag


class _Ta

In [163]:
# lets count the tokens in the dataset
# Initialize a counter
from collections import Counter

counter = Counter()

# Iterate over the dataset and tokenize each entry
for entry in tqdm.tqdm(code_data):
    # Tokenize the code
    tokens = tokenizer.tokenize(entry["text"])
    
    # Update the counter
    counter.update(tokens)



100%|██████████| 10000/10000 [01:25<00:00, 116.63it/s]


In [164]:
# Use plotly to plot the distribution of tokens for the 
# top 20 tokens in the dataset, do percentages
import plotly.express as px
import pandas as pd


# Get the top 20 tokens
top_20_tokens = counter.most_common(20)
percentages = [count/sum(counter.values())*100 for _, count in top_20_tokens]

# Create a dataframe
df = pd.DataFrame({"token": [token for token, _ in top_20_tokens], "percentage": percentages})

# Plot the dataframe
fig = px.bar(df, x="token", y="percentage", title="Top 20 tokens in the dataset")
fig.show()

In [165]:
# now lets look at the distribution conditioned on the presence of parentheses
# do this by subsellecting the counter

# subsellect the counter for tokens with parentheses, token_parens
counter_parens = Counter({token: count for token, count in counter.items() if token in tokens_parens})

# plot the top 20 tokens with parentheses and use percentages
top_20_tokens_parens = counter_parens.most_common(20)
percentages_parens = [count/sum(counter_parens.values())*100 for _, count in top_20_tokens_parens]

# Create a dataframe
df_parens = pd.DataFrame({"token": [token for token, _ in top_20_tokens_parens], "percentage": percentages_parens})

# Plot the dataframe
fig = px.bar(df_parens, x="token", y="percentage", title="Top 20 tokens with parentheses in the dataset")
fig.show()


In [166]:
# make  a list of the tokens with ) in them that are top20, make sure they dont have (
top_close_parens = [token for token, _ in top_20_tokens_parens if "(" not in token]
print(top_close_parens)

# get their vocab indices
top_close_parens_idx = [tokenizer.get_vocab()[token] for token in top_close_parens]
print(top_close_parens_idx)

[')', '):', "')", '),', '))', '")', '])', 'Ġ)', ').', "'),"]
[11, 2192, 3291, 581, 1210, 2719, 3185, 2312, 481, 11162]


# Make a clean dataset

In [31]:
def chunk_code(code):
    lines = code.split('\n')
    chunks = []
    chunk = ''
    paren_count = 0
    
    for line in lines:
        paren_count += line.count('(') - line.count(')')
        chunk += line + '\n'
        # If all parentheses are closed and the line had parentheses
        if paren_count == 0 and ('(' in line or ')' in line):
            chunks.append(chunk)
            chunk = ''
    
    # After going through all lines, if there's still some chunk left, append it.
    if chunk:
        chunks.append(chunk)

    return chunks

def combine_chunks(chunks):
    combined = []
    combine_pattern = [1, 2, 3]
    
    while chunks:
        for combine_count in combine_pattern:
            if len(chunks) < combine_count:
                # If not enough chunks remain, append remaining chunks and end loop
                combined.append(''.join(chunks))
                chunks = []
                break
            else:
                # Combine chunks according to pattern and remove them from chunk list
                combined.append(''.join(chunks[:combine_count]))
                chunks = chunks[combine_count:]
    
    return combined

In [34]:
clean_tokens = ["(", ")"]
# dirty tokens are all tokens that have parens other than clean tokens
dirty_tokens = [token for token in tokens_parens if token not in clean_tokens]

chunked_data = []

# go through all the data
pbar = tqdm.notebook.tqdm(total=len(code_data))
for ind, entry in tqdm.notebook.tqdm(enumerate(code_data)):

    # get the text for this entry
    text = entry["text"]

    # chunk the text
    chunks = chunk_code(text)

    # combine the chunks
    combined_chunks = combine_chunks(chunks)

    # go through the combined chunks and tokenize
    for chunk in combined_chunks:
        # tokenize the chunk
        tokens = tokenizer.tokenize(chunk)

        # get the number of clean tokens
        num_clean = sum([token in clean_tokens for token in tokens])

        # get the number of dirty tokens
        num_dirty = sum([token in dirty_tokens for token in tokens])

        if num_clean > 0 and num_dirty == 0:
            chunked_data.append(chunk)


print(f"Number of entries in chunked_data: {len(chunked_data)}")

  0%|          | 0/10000 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Number of entries in chunked_data: 41037


In [35]:
# now lets go through chunked_data, tokenize it, and count the tokens that contain ( or )

# Initialize a counter
counter = Counter()

# Iterate over the dataset and tokenize each entry
for entry in tqdm.tqdm(chunked_data):
    # Tokenize the code
    tokens = tokenizer.tokenize(entry)
    
    # Update the counter
    counter.update(tokens)

# Use plotly to plot the distribution of tokens for the
# top 20 tokens in the dataset, do percentages
import plotly.express as px
import pandas as pd

# Get the top 20 tokens
top_20_tokens = counter.most_common(20)
percentages = [count/sum(counter.values())*100 for _, count in top_20_tokens]

# Create a dataframe
df = pd.DataFrame({"token": [token for token, _ in top_20_tokens], "percentage": percentages})

# Plot the dataframe
fig = px.bar(df, x="token", y="percentage", title="Top 20 tokens in the clean dataset")
fig.show()

# now do top 20 tokens that contain a ( or )
# subsellect the counter for tokens with parentheses, token_parens
counter_parens = Counter({token: count for token, count in counter.items() if token in tokens_parens})

# plot the top 20 tokens with parentheses and use percentages
top_20_tokens_parens = counter_parens.most_common(20)
percentages_parens = [count/sum(counter_parens.values())*100 for _, count in top_20_tokens_parens]

# Create a dataframe
df_parens = pd.DataFrame({"token": [token for token, _ in top_20_tokens_parens], "percentage": percentages_parens})

# Plot the dataframe
fig = px.bar(df_parens, x="token", y="percentage", title="Top 20 tokens with parentheses in the clean dataset")
fig.show()




[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

100%|██████████| 41037/41037 [00:12<00:00, 3359.81it/s]


# Run clean data through model, look at predictions over paren tokens

In [257]:
import torch

# get random entry from chunked_data
entry_idx = np.random.randint(len(chunked_data))
code_string = chunked_data[entry_idx]

# tokenize the entry
tokens = tokenizer.tokenize(code_string)
logits = model(code_string, return_type="logits")

# compute log probs from logits
probs = torch.softmax(logits, dim=-1)

# get the log prob for the top_close_parens_idx
probs_close_parens = probs[0,:,top_close_parens_idx] # shape (num_tokens, num_close_parens)
close_paren_index = tokenizer.convert_tokens_to_ids(")")
print(close_paren_index)
# get the log prob for the close_paren_index
probs_top_close_paren = probs[0, :, close_paren_index]

# sum over the tokens
probs_close_parens_sum = probs_close_parens.sum(dim=1) # shape (num_tokens,)


print(probs_top_close_paren.shape)

print(code_string)

from html import escape
import colorsys

from IPython.display import display, HTML
def create_html(strings, values, tokenizer):
    # escape strings to deal with tabs, newlines, etc.
    strings = [tokenizer.convert_tokens_to_string(x) for x in strings]
    escaped_strings = [escape(s, quote=True) for s in strings]
    processed_strings = [s.replace('\n', '<br/>').replace('\t', '&emsp;').replace(" ", "&nbsp;") for s in escaped_strings]


    # scale values
    max_value = max(max(values), -min(values))
    scaled_values = [v / max_value * 0.5 for v in values]

    # create html
    html = ""
    for s, v in zip(processed_strings, scaled_values):
        if v < 0:
            hue = 0  # hue for red in HSV
        else:
            hue = 0.66  # hue for blue in HSV
        rgb_color = colorsys.hsv_to_rgb(hue, v, 1) # hsv color with hue 0.66 (blue), saturation as v, value 1
        hex_color = '#%02x%02x%02x' % (int(rgb_color[0]*255), int(rgb_color[1]*255), int(rgb_color[2]*255))
        html += f'<span style="background-color: {hex_color}; border: 1px solid lightgray; font-size: 16px; border-radius: 3px;">{s}</span>'

    display(HTML(html))

print("log probs of close parentheses")

create_html(tokens, 5+ torch.log(probs_close_parens_sum), tokenizer)

print("log probs of close parentheses, only )")
create_html(tokens, 5+ torch.log(probs_top_close_paren), tokenizer)

print("probs of close parentheses")
create_html(tokens, probs_close_parens_sum, tokenizer)

print("probs of close parentheses, only )")
create_html(tokens, probs_top_close_paren, tokenizer)


# make a plot with the stringed token on the x axis and the probs of each of the top close parentheses on the y axis
# probs_close_parens is shape (num_tokens, num_close_parens) and the close_parens are given by top_close_parens

# make a df with a column for each of the top close parentheses, use the top_close_parens as the column name
df = pd.DataFrame(probs_close_parens.detach().cpu().numpy(), columns=[tokenizer.convert_tokens_to_string([token]) for token in top_close_parens])

# plot the df using plotly, with a different line for every column
fig = px.line(df, x=df.index, y=df.columns, title="Probabilities of top close parentheses")
fig.show()

11
torch.Size([59])
        irange = img.findOptimalRange(hist, edges, 1 / 256)
        rgb8 = img.DataArray2RGB(data, irange)

    # save to file
    scipy.misc.imsave(filename, rgb8)

log probs of close parentheses


log probs of close parentheses, only )


probs of close parentheses


probs of close parentheses, only )


# Verify that the model is good at the task



In [313]:
# go through the dataset and calculate ( minus ) for every position

supressing_symbols = ["(", "+", "_", ".", ","]
all_cumsums = []
for code_text in tqdm.notebook.tqdm(chunked_data):

    # tokenize the code
    tokens = tokenizer.tokenize(code_text)

    # if a token contains "(" then +1 to the cumsum
    # if a token contains ")" then -1 to the cumsum

    # initialize the cumsum
    open_inds = ["(" in x for x in tokens]
    close_inds = [")" in x for x in tokens]
    suppress_inds = [x in supressing_symbols for x in tokens]

    cumsum = np.cumsum(open_inds) - np.cumsum(close_inds)
    # shift cumsum one to the right
    cumsum = np.concatenate([[0], cumsum])

    # if there is a suppressing symbol, set the cumsum to 0 in the next position
    for i in range(1, len(suppress_inds)):  # Start from 1 as we can't suppress before the first token
        if suppress_inds[i-1]:  # If the previous token is a suppressing symbol
            cumsum[i] = 0

    all_cumsums.append(cumsum)

  




  0%|          | 0/41037 [00:00<?, ?it/s]

In [314]:
create_html(tokens, cumsum, tokenizer)

In [351]:
# go through all code texts and check how the prob of ")" changes with the cumsum
# get the log prob for the top_close_parens_idx
all_dfs = []
for i, code_chunk in tqdm.notebook.tqdm(enumerate(chunked_data)):

    # tokenize the entry
    tokens = tokenizer.tokenize(code_chunk)
    logits = model(code_chunk, return_type="logits")
    probs = torch.softmax(logits, dim=-1)

    # get the log prob for the top_close_parens_idx
    close_paren_index = tokenizer.convert_tokens_to_ids(")")
    # get the log prob for the close_paren_index
    probs_top_close_paren = torch.log(probs[0, :, close_paren_index])
    df = pd.DataFrame({"cumsum": all_cumsums[i], "probs_close_parens": probs_top_close_paren.detach().cpu().numpy()})
    all_dfs.append(df)

# concatenate all the dfs
all_dfs = pd.concat(all_dfs)
    





0it [00:00, ?it/s]

ValueError: All arrays must be of the same length

In [353]:
# get means
all_dfs = pd.concat(all_dfs)
means = all_dfs.groupby("cumsum").mean()
means = means.reset_index()
print(means)

   cumsum  probs_close_parens
0      -1          -11.110972
1       0          -11.674874
2       1           -4.509087
3       2           -2.693418


In [350]:
import plotly.graph_objects as go

# Separate data based on cumsum values
df_0 = all_dfs[all_dfs['cumsum'] == 0]
df_1 = all_dfs[all_dfs['cumsum'] == 1]
df_2 = all_dfs[all_dfs['cumsum'] == 2]
bins = np.linspace(0, 1, 10)

# use plotly express to make histogram
fig = go.Figure()
fig.add_trace(go.Histogram(x=df_0['probs_close_parens'], name='cumsum = 0', histnorm='probability', xbins=dict(start=0, end=1, size=0.1)))
fig.add_trace(go.Histogram(x=df_1['probs_close_parens'], name='cumsum = 1', histnorm='probability', xbins=dict(start=0, end=1, size=0.1)))
fig.add_trace(go.Histogram(x=df_2['probs_close_parens'], name='cumsum = 2', histnorm='probability', xbins=dict(start=0, end=1, size=0.1)))
fig.update_layout(barmode='overlay', title='Probability of close parentheses for different cumsum values')
fig.update_traces(opacity=0.75)
# y log
#fig.update_layout(yaxis_type="log")
# xlabel
fig.update_layout(xaxis_title="Probability of close parentheses")
# ylabel = 

fig.show()