In [1]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [2]:
model = HookedTransformer.from_pretrained("gelu-4l")
model("Hello world!")


Loaded pretrained model gelu-4l into HookedTransformer


tensor([[[ 6.3408, -6.8591, -6.8589,  ..., -6.8638, -6.8522, -6.8653],
         [ 2.8462, -6.8532, -6.8352,  ..., -6.8186, -6.8169, -6.8409],
         [ 4.6911, -6.8472, -6.8372,  ..., -6.8757, -6.8512, -6.8663],
         [12.6366, -6.6353, -6.6095,  ..., -6.6533, -6.6362, -6.6393]]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [3]:
model.cfg


HookedTransformerConfig:
{'act_fn': 'gelu',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 2048,
 'd_model': 512,
 'd_vocab': 48262,
 'd_vocab_out': 48262,
 'device': 'cuda',
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.035355339059327376,
 'model_name': 'GELU_4L512W_C4_Code',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 8,
 'n_layers': 4,
 'n_params': 12582912,
 'normalization_type': 'LNPre',
 'original_architecture': 'neel',
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'rotary_dim': None,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': None,
 'tokenizer_name': 'NeelNanda/gpt-neox-tokenizer-digits',
 'use_attn_result': False,
 'use_attn_scale': True,
 'use_hook_tokens': False,
 'use_local_attn': False,
 'use_split_q

In [4]:
dir(model)

['OV',
 'QK',
 'T_destination',
 'W_E',
 'W_E_pos',
 'W_K',
 'W_O',
 'W_Q',
 'W_U',
 'W_V',
 'W_in',
 'W_out',
 'W_pos',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_backward_pre_hooks',
 '_buffers',
 '_call_impl',
 '_forward_hooks',
 '_forward_hooks_with_kwargs',
 '_forward_pre_hooks',
 '_forward_pre_hooks_with_kwargs',
 '_get_backward_hooks',
 '_get_backward_pre_hooks',
 '_get_name',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_post_hooks',
 '_load_state_dict_pre_hooks',
 '_maybe_warn_non_full_backward_hook',
 '_modules',
 '_named_members',

In [5]:
d_head = model.cfg.d_head
d_mlp = model.cfg.d_mlp
d_model = model.cfg.d_model
d_vocab = model.cfg.d_vocab
n_ctx = model.cfg.n_ctx
n_heads = model.cfg.n_heads
n_layers = model.cfg.n_layers

# make a nice printout of the above values
print(f"d_head: {d_head}\nd_mlp: {d_mlp}\nd_model: {d_model}\nd_vocab: {d_vocab}\nn_ctx: {n_ctx}\nn_heads: {n_heads}\nn_layers: {n_layers}")


d_head: 64
d_mlp: 2048
d_model: 512
d_vocab: 48262
n_ctx: 1024
n_heads: 8
n_layers: 4


In [6]:
tokenizer = model.tokenizer
# Assume 'tokenizer' is the tokenizer used by your model
all_tokens = list(tokenizer.get_vocab().keys())

# Filter the list to include only tokens that contain ")"
tokens_with_bracket = [token for token in all_tokens if ")" in token]
print(f"Number of tokens containing ')': {len(tokens_with_bracket)}")

tokens_with_close_bracket = [token for token in all_tokens if "(" in token]
print(f"Number of tokens containing '(': {len(tokens_with_close_bracket)}")

# do any tokens contain both ( and )?
tokens_with_both_brackets = [token for token in all_tokens if "(" in token and ")" in token]
print(f"Number of tokens containing both '(': {len(tokens_with_both_brackets)}")

# all tokens with parens
tokens_with_all_brackets = [token for token in all_tokens if "(" in token or ")" in token]

# print the tokens that contain both
# make it nicely formatted with tabs and newlines
# 8 examples per line
print("\n".join(["\t".join(tokens_with_both_brackets[i : i + 8]) for i in range(0, len(tokens_with_both_brackets), 8)]))

Number of tokens containing ')': 354
Number of tokens containing '(': 266
Number of tokens containing both '(': 61
)|$(	)/((-	)}(	)/((	Ġ()	>();	)[(	()).
Ġ().	Ġ(+)	)*(-	)}(\	(),	():	))/((	)-(
()">	)**(	()"	))/(	(()	()))	)âĢĵ(	Ġ(),
)=(	)+(	<>();	Ġ()]{}	Ġ()](\	))/((-	})(	))/(-
()->	)**(-	)](	()`	)/(	()));	),(	))**(
>()	());	)*(	(+)	)(	()[	)](#	().
Ġ(%)	)/(-	).](	(){	()	()</	)}^{(	Ġ();
())	();	)(\	()),	))**(-


In [7]:
# Import the load_dataset function from the Hugging Face's datasets library
from datasets import load_dataset
import tqdm

# Load the "NeelNanda/code-10k" dataset and take the training split
code_data = load_dataset("NeelNanda/code-10k", split="train")

from collections import Counter

# Initialize a counter
counter = Counter()

# Iterate over the dataset and tokenize each entry
for entry in tqdm.tqdm(code_data):
    # Tokenize the code
    tokens = tokenizer.tokenize(entry["text"])
    
    # Update the counter
    counter.update(tokens)

# Now, filter the counter to get counts of the desired tokens
tokens_with_bracket_counts = {token: count for token, count in counter.items() if ")" in token}
tokens_with_close_bracket_counts = {token: count for token, count in counter.items() if "(" in token}
tokens_with_both_brackets_counts = {token: count for token, count in counter.items() if "(" in token and ")" in token}

print(f"Distribution of tokens containing ')': {tokens_with_bracket_counts}")
print(f"Distribution of tokens containing '(': {tokens_with_close_bracket_counts}")
print(f"Distribution of tokens containing both '(': {tokens_with_both_brackets_counts}")


Found cached dataset parquet (C:/Users/adamimos/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--code-10k-80b300d967669109/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
100%|██████████| 10000/10000 [01:34<00:00, 106.29it/s]

Distribution of tokens containing ')': {')': 278592, '):': 145964, '])': 21368, '()': 88462, "')": 49040, 'Ġ...)': 133, '():': 13324, '),': 38552, ')]': 3415, "'):": 4749, '))': 36801, ')))': 4044, '").': 1784, "'))": 5100, ']):': 932, ').': 16166, "'])": 7043, "'),": 11453, '")': 29976, 'Ġ()': 395, ')[': 2886, '())': 8747, ']))': 2575, ')):': 2882, ']).': 872, "Ġ')": 1000, "']))": 817, '})': 5352, '().': 6991, "');": 247, '.")': 2156, ':")': 260, '(),': 7406, ')",': 623, 'Ġ")': 667, "Ġ'')": 933, '"])': 1778, '+)': 543, ')(': 1279, '>)': 81, ')=(': 11, '()))': 1073, '");': 1144, ')}': 686, ')).': 787, 'Ġ)': 18767, ')**': 407, ")',": 723, ")'": 894, 'Ġ})': 1109, ')),': 3895, ')."': 195, 'Ġ):': 2388, ')"': 1540, ');': 2567, '")]': 211, '__)': 1215, "').": 2631, ')/(': 278, '"))': 3912, '()).': 323, ')+': 2167, ')/': 1368, ')$': 261, 'Ġ).': 731, '.)': 778, '()`': 205, ',)': 640, '"),': 5822, 'Ġ),': 3827, ')],': 665, '()[': 1355, '}).': 67, '();': 617, 'Ġ});': 43, '}})': 174, ']);': 193, '




In [8]:
# Print the first entry to see the keys
# plot the distribution over tokens containing ')' with plotly express
# make it a prob distribution, have it so that when i hover it shows the
# token and the probability, also ordered
import plotly.express as px

# Sort the tokens by their counts
sorted_tokens_with_bracket_counts = sorted(tokens_with_bracket_counts.items(), key=lambda x: x[1], reverse=True)

# Get the counts
counts = [count for _, count in sorted_tokens_with_bracket_counts]

# Normalize the counts
counts = [100*count / sum(counts) for count in counts]

# Get the tokens
tokens = [token for token, _ in sorted_tokens_with_bracket_counts]

# Plot the distribution
fig = px.bar(x=tokens, y=counts)

# Set the x-axis title
fig.update_xaxes(title_text="Token")

# Set the y-axis title
fig.update_yaxes(title_text="Prevelance (%)")

# zoom in to only show tokens with prob > 0.1%
# find the index of the first token with prob < 0.1%
for i, count in enumerate(counts):
    if count < 1.:
        break

# Set the x-axis range
fig.update_xaxes(range=[0, i])

# add title
fig.update_layout(title_text="Distribution of tokens containing ')'")
fig.show()

# now lets do the same for tokens containing '('
# Sort the tokens by their counts
sorted_tokens_with_close_bracket_counts = sorted(tokens_with_close_bracket_counts.items(), key=lambda x: x[1], reverse=True)

# Get the counts
counts = [count for _, count in sorted_tokens_with_close_bracket_counts]

# Normalize the counts
counts = [100*count / sum(counts) for count in counts]

# Get the tokens
tokens = [token for token, _ in sorted_tokens_with_close_bracket_counts]

# Plot the distribution
fig = px.bar(x=tokens, y=counts)

# Set the x-axis title
fig.update_xaxes(title_text="Token")

# Set the y-axis title
fig.update_yaxes(title_text="Prevelance (%)")

# zoom in to only show tokens with prob > 0.1%
# find the index of the first token with prob < 0.1%
for i, count in enumerate(counts):
    if count < 1.0:
        break

# Set the x-axis range
fig.update_xaxes(range=[0, i])

# add title
fig.update_layout(title_text="Distribution of tokens containing '('")
fig.show()

# now lets do the same for tokens containing both '(' and ')'




In [9]:
# how much of the dataset contains only "(" and ")", and none of the other parenthesis tokens?

clean_tokens = [")", "):", "')", "),", "))",
                "(", " (", "('"]
# dirty tokens are tokens_with_all_brackets except we have to take out the clean tokens
dirty_tokens = set(tokens_with_all_brackets) - set(clean_tokens)

print(f"Number of all tokens: {len(all_tokens)}")
print(f"Number of tokens with all brackets: {len(tokens_with_all_brackets)}")
print(f"Number of clean tokens: {len(clean_tokens)}")
print(f"Number of dirty tokens: {len(dirty_tokens)}")

# build a pandas dataframe to keep track of data_index, num_clean, num_dirty
# now lets go through all the data
import pandas as pd
df = pd.DataFrame(columns=["data_index", "num_clean", "num_dirty"])
# use tqdm for notebook progress bar
for ind, entry in tqdm.notebook.tqdm(enumerate(code_data)):
    # Tokenize the code
    tokens = tokenizer.tokenize(entry["text"])

    # get the number of clean tokens
    num_clean = sum([token in clean_tokens for token in tokens])

    # get the number of dirty tokens
    num_dirty = sum([token in dirty_tokens for token in tokens])

    # add the data to the dataframe, use pandas.concat
    df = pd.concat([df, pd.DataFrame([[ind, num_clean, num_dirty]], columns=["data_index", "num_clean", "num_dirty"])])




Number of all tokens: 48262
Number of tokens with all brackets: 559
Number of clean tokens: 8
Number of dirty tokens: 552


0it [00:00, ?it/s]

In [10]:
# how many examples have at least one clean token? and 0 dirty tokens?

# get the number of examples with at least one clean token
num_examples_with_at_least_one_clean_token = len(df[df["num_clean"] > 0])

# get the number of examples with 0 dirty tokens
num_examples_with_zero_dirty_tokens = len(df[df["num_dirty"] == 0])

print(f"Number of examples with at least one clean token: {num_examples_with_at_least_one_clean_token}")
print(f"Number of examples with 0 dirty tokens: {num_examples_with_zero_dirty_tokens}")

# how many examples have at least one clean token? and 0 dirty tokens?
num_clean_only = len(df[(df["num_clean"] > 0) & (df["num_dirty"] == 0)])

print(f"Number of examples with at least one clean token and 0 dirty tokens: {num_clean_only}")

Number of examples with at least one clean token: 9985
Number of examples with 0 dirty tokens: 35
Number of examples with at least one clean token and 0 dirty tokens: 30


In [11]:
# how many tokens is a random example

# get the first data
data = code_data[0]

# tokenize the data
tokens = tokenizer.tokenize(data["text"])

# print the number of tokens
print(f"Number of tokens in the first data: {len(tokens)}")

Number of tokens in the first data: 8677


In [31]:
# do this for first 100 examples

# get the first 100 data
data = code_data[0]

text = data['text']

num_newlines = text.count('\n')
print(f"Number of newlines in the first data: {num_newlines}")

# chunk the data by newlines
chunks = text.split('\n')

# print the number of chunks
print(f"Number of chunks in the first data: {len(chunks)}")




Number of newlines in the first data: 833
Number of chunks in the first data: 834
Number of chunks in the first data with at least one '(': 291
# Copyright (c) 2019 J. Alvarez-Jarreta and C.J. Brasher
"""Graphical User Interface (GUI) to manage the parameters' collection.
class _TaggedToggleButton(widgets.ToggleButton):
    def __init__(self, tag, **kwargs):
        widgets.ToggleButton.__init__(self, **kwargs)
class _TaggedCheckbox(widgets.Checkbox):
    def __init__(self, tag, **kwargs):
        widgets.Checkbox.__init__(self, **kwargs)
class _TaggedButton(widgets.Button):
    def __init__(self, tag, **kwargs):
        widgets.Button.__init__(self, **kwargs)
class LFParametersGUI(LFParameters):
        _parameters  (Private[collections.OrderedDict])
        _floatPointPrecision  (Private[int])
        _floatStep  (Private[float])
        _style  (Private[dict])
        _inputWidth  (Private[str])
        _widgets  (Private[collections.OrderedDict])
            >>> LFParametersGUI()
 

In [39]:
import pandas as pd

import itertools

# make a function that takes in text data and splits it into chunks by \n
def chunk_by_newline(text):
    # split the text by \n
    chunks = text.split('\n')

    # return the chunks
    return chunks

def filter_chunks(chunks):
    """
    chunks is a list of strings
    we only want chunks that have at least one (
    and we also want to remove chunks that have a ) that precedes the first (
    """

    # initialize a list to keep track of the filtered chunks
    filtered_chunks = []

    # loop through each chunk
    for chunk in chunks:
        # check if the chunk has at least one (
        if '(' in chunk:
            # check if the chunk has a ) that precedes the first (
            if ')' not in chunk[:chunk.index('(')]:
                # add the chunk to the filtered_chunks
                filtered_chunks.append(chunk)

    # return the filtered_chunks
    return filtered_chunks


# now lets go through the dataset and chunka and filter things!

# initialize a list to keep track of the filtered data
filtered_data = []

# use tqdm for notebook progress bar
for ind, entry in tqdm.notebook.tqdm(enumerate(code_data)):
    # get the text
    text = entry["text"]

    # chunk the text
    chunks = chunk_by_newline(text)

    # filter the chunks
    filtered_chunks = filter_chunks(chunks)

    # add the filtered chunks to the filtered data
    filtered_data.append(filtered_chunks)

# filtered_data is a list of lists, but i want it to just be a list
# use itertools.chain.from_iterable
filtered_data = list(itertools.chain.from_iterable(filtered_data))


# how many entries doest filtered_data have?
print(f"Number of entries in filtered_data: {len(filtered_data)}")


0it [00:00, ?it/s]

Number of entries in filtered_data: 776720


In [46]:
# now lets go through our filtered data and tokenize it, and check for
# clean vs dirty tokens

clean_tokens = [")", "("]
dirty_tokens = set(tokens_with_all_brackets) - set(clean_tokens)

valid_example = 0

clean_data = []
# use tqdm for notebook progress bar
pbar = tqdm.notebook.tqdm(total=len(filtered_data))
for ind, entry in tqdm.notebook.tqdm(enumerate(filtered_data)):
    
    # does this entry have at least one clean token and no dirty tokens?
    # tokenize the entry
    tokens = tokenizer.tokenize(entry)

    # get the number of clean tokens
    num_clean = sum([token in clean_tokens for token in tokens])

    # get the number of dirty tokens
    num_dirty = sum([token in dirty_tokens for token in tokens])

    if num_clean > 0 and num_dirty == 0:
        valid_example += 1
        clean_data.append(entry)

    # update the progress bar with the percentage of valid examples
    #pbar.update(1)
    # pbar.set_description(f"Percentage of valid examples: {valid_example / (ind + 1)}")

print(f"Number of valid examples: {valid_example}")





  0%|          | 0/776720 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Number of valid examples: 203295


In [47]:
# save clean_data, which is a list of strings, as a csv
import csv

with open('clean_data.csv', 'w') as f:
    writer = csv.writer(f)
    writer.writerow(clean_data)

# how big is the csv file, in MB?
import os

print(f"Size of clean_data.csv: {os.path.getsize('clean_data.csv') / 1e6} MB")

Size of clean_data.csv: 10.176664 MB


In [None]:
def chunk_by_paren_closure(text):
    """This function takes in text and splits it into chunks by paren closure
    that means it finds the first open paren and then goes until the valid closing paren
    """

    


# initialize a list to keep track of the filtered data
filtered_data = []

# use tqdm for notebook progress bar
for ind, entry in tqdm.notebook.tqdm(enumerate(code_data)):
    # get the text
    text = entry["text"]

    # chunk the text
    chunks = chunk_by_newline(text)

    # filter the chunks
    filtered_chunks = filter_chunks(chunks)

    # add the filtered chunks to the filtered data
    filtered_data.append(filtered_chunks)

# filtered_data is a list of lists, but i want it to just be a list
# use itertools.chain.from_iterable
filtered_data = list(itertools.chain.from_iterable(filtered_data))


# how many entries doest filtered_data have?
print(f"Number of entries in filtered_data: {len(filtered_data)}")

In [136]:
# get the first text
text = code_data[0]["text"]

def chunk_code(code):
    lines = code.split('\n')
    chunks = []
    chunk = ''
    paren_count = 0
    
    for line in lines:
        paren_count += line.count('(') - line.count(')')
        chunk += line + '\n'
        # If all parentheses are closed and the line had parentheses
        if paren_count == 0 and ('(' in line or ')' in line):
            chunks.append(chunk)
            chunk = ''
    
    # After going through all lines, if there's still some chunk left, append it.
    if chunk:
        chunks.append(chunk)

    return chunks

def combine_chunks(chunks):
    combined = []
    combine_pattern = [1, 2, 3]
    
    while chunks:
        for combine_count in combine_pattern:
            if len(chunks) < combine_count:
                # If not enough chunks remain, append remaining chunks and end loop
                combined.append(''.join(chunks))
                chunks = []
                break
            else:
                # Combine chunks according to pattern and remove them from chunk list
                combined.append(''.join(chunks[:combine_count]))
                chunks = chunks[combine_count:]
    
    return combined




chunks = chunk_code(text)

print(text)

# Copyright (c) 2019 J. Alvarez-Jarreta and C.J. Brasher
#
# This file is part of the LipidFinder software tool and governed by the
# 'MIT License'. Please see the LICENSE file that should have been
# included as part of this software.
"""Graphical User Interface (GUI) to manage the parameters' collection.
"""

from collections import OrderedDict
import os

from IPython.display import display
from ipywidgets import widgets, Layout
import pandas

from LipidFinder.Configuration import LFParameters
from LipidFinder._utils import normalise_path


class _TaggedToggleButton(widgets.ToggleButton):
    """Add "tag" attribute to widgets.ToggleButton class."""

    def __init__(self, tag, **kwargs):
        widgets.ToggleButton.__init__(self, **kwargs)
        self.tag = tag


class _TaggedCheckbox(widgets.Checkbox):
    """Add "tag" attribute to widgets.Checkbox class."""

    def __init__(self, tag, **kwargs):
        widgets.Checkbox.__init__(self, **kwargs)
        self.tag = tag


class _Ta

In [207]:
chunked_data = []

# go through all the data
pbar = tqdm.notebook.tqdm(total=len(code_data))
for ind, entry in tqdm.notebook.tqdm(enumerate(code_data)):

    # get the text for this entry
    text = entry["text"]

    # chunk the text
    chunks = chunk_code(text)

    # combine the chunks
    combined_chunks = combine_chunks(chunks)

    # go through the combined chunks and tokenize
    for chunk in combined_chunks:
        # tokenize the chunk
        tokens = tokenizer.tokenize(chunk)

        # get the number of clean tokens
        num_clean = sum([token in clean_tokens for token in tokens])

        # get the number of dirty tokens
        num_dirty = sum([token in dirty_tokens for token in tokens])

        if num_clean > 0 and num_dirty == 0:
            chunked_data.append(chunk)









  0%|          | 0/10000 [00:00<?, ?it/s]

0it [00:00, ?it/s]

In [208]:
# how many entries does chunked_data have?
print(f"Number of entries in chunked_data: {len(chunked_data)}")

Number of entries in chunked_data: 41037


In [209]:
import random
# print some examples
for j in range(1):
    # random index
    i = random.randint(0, len(chunked_data) - 1)
    print(chunked_data[i])

    # Authors
    authors = models.CharField(max_length=250)
    # Publication date
    publicationDate = models.DateField(editable=True, auto_now=False, auto_now_add=False)



In [210]:
import csv

# assuming chunked_data is your list of strings

# specify the name of the csv file where the data will be saved
filename = 'multiline_data.csv'

# writing to csv file  
with open(filename, 'w') as csvfile:  
    # creating a csv writer object  
    csvwriter = csv.writer(csvfile)  
        
    # writing each chunk as a row in the csv file
    for chunk in chunked_data:
        csvwriter.writerow([chunk]) 

In [211]:
# what is the size of the csv file?
import os

print(f"Size of multiline_data.csv: {os.path.getsize('multiline_data.csv') / 1e6} MB")

Size of multiline_data.csv: 5.67938 MB


In [212]:
code_text_file = open("multiline_data.csv", "r")
code_text = code_text_file.read()
code_text_file.close()

# get the number of characters in the code text
num_chars = len(code_text)
print(f"Number of characters in code text: {num_chars}")

Number of characters in code text: 5494734


In [213]:
import numpy as np
np.sum([len(x) for x in chunked_data])

5307135

In [214]:
code_text[0]

'"'

In [215]:
len(chunked_data)

41037

In [216]:
import pandas as pd

# Convert the list into a DataFrame
df = pd.DataFrame(chunked_data, columns=['Chunks'])

# Write the DataFrame to a CSV file
df.to_csv('chunks.csv', index=False)

In [217]:
# load the csv file
df = pd.read_csv('chunks.csv')

# convert to list
chunked_data = df['Chunks'].tolist()

In [218]:
len(chunked_data)

41037

In [237]:
import torch

logits = model(chunked_data[0]) # batch, seq_len, n_vocab
prob = torch.softmax(logits, dim=2) # batch, seq_len, n_vocab

token_index_OPEN_PAREN = tokenizer.convert_tokens_to_ids('(')
token_index_CLOSE_PAREN = tokenizer.convert_tokens_to_ids(')')

print(f"token_index_OPEN_PAREN: {token_index_OPEN_PAREN}")
print(f"token_index_CLOSE_PAREN: {token_index_CLOSE_PAREN}")

print(100*prob[0,:,token_index_CLOSE_PAREN])

# make a plot using plotly express to visualize the probabilities
# with the x axis being the string tokenized
# and the y axis being the probability of the token being a close paren
import plotly.express as px

string = chunked_data[0]
tokens = tokenizer.tokenize(string)

fig = px.line(prob[0,:,token_index_CLOSE_PAREN].detach().cpu().numpy())
fig.show()



token_index_OPEN_PAREN: 10
token_index_CLOSE_PAREN: 11
tensor([3.4030e-01, 1.0094e-03, 5.0231e-01, 1.3257e-04, 2.3730e-05, 4.6753e-04,
        9.4137e-04, 7.2276e-04, 7.6869e-04, 8.5612e-06, 5.5276e-09, 2.2622e-05,
        9.4735e-08, 2.8002e-05, 7.0168e-03, 1.3107e-03, 2.4875e+01, 6.7558e+01,
        6.5559e-07, 1.2075e-04, 3.2383e-05, 4.1141e-05, 4.8454e-05, 2.2031e-02],
       device='cuda:0', grad_fn=<MulBackward0>)


In [226]:
vocab_dict = tokenizer.get_vocab()
print(f"Number of tokens in vocab: {len(vocab_dict)}")

Number of tokens in vocab: 48262


In [227]:
# tokenize the first example
tokens = tokenizer.tokenize(chunked_data[0])
print(f"Number of tokens in first example: {len(tokens)}")

Number of tokens in first example: 23
