In [1]:
# !pip install datasets
# !pip install regex

In [2]:
from datasets import load_dataset

# Loading Wikipedia dataset, just 1% of the data
dataset = load_dataset("wikipedia", "20220301.en", split="train[:1%]")

# Saving as text file
with open("wikipedia_train.txt", "w", encoding='utf-8') as f:
    for item in dataset:
        f.write(item["text"].replace("\n", " ") + "\n")

## Motivation for BPE Tokenization
We are implementing a Byte Pair Encoding Tokenizer. There are a range of tokenizing options like Word Piece tokenizer or a simple character lookup tokenizer. These tokenization techniques are simpler to implement, however they are ineffecient because they use more tokens to represent sentences when compared to BPE tokenizer.

BPE tokenizer is a state of the art technique to represent sentences and it is used in models like GPT-2. The high level overview of BPE is that, frequently occuring tokens will be merged to form a singular new token rather represnting the component parts as multiple individual tokens. To learn more about BPE, check this out: https://en.wikipedia.org/wiki/Byte_pair_encoding

### This is notebook primarily focussed on training the Tokenizer and to show the process more intuitively. Once trained, this tolenizer can be called from a Tokenizer class which is defined in the tokenizer.py file

In [4]:
import os
import regex as re
from tqdm import tqdm
from collections import Counter, defaultdict

# Config
file_path = 'wikipedia_train.txt'
chunk_size = 10000  # read the file in chunks of 0.01 MB to keep the calculations in memory.
vocab_size = 997 # ideally the vocab size would be 1000 after we add CLS and PAD tokens.
orig_vocab_size = 255 # since we are using utf-8, values 0-255 are blocked.
new_token = orig_vocab_size  # values from 256-997 are free to be assigned.
merged = {} # holds the combination of tokens that merge to form a new token.

# Tokenizer Pattern
# This pattern is a slight modification of the GPT-2 pattern to split sentences based on space characters. 
# This modification is necessary for downstream tasks like NER where spaces between words are better left uncombined.
pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d|\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+|\s+""")


# Read text file in chunks. And tokenize the text chunks. Then update the tokens in a global list, token_seqs. 
# token_seqs is a list of list of tokens. This is because, when the text goes through the above regex pattern, it is split into lists.
# Then Tokenization & token merging will happen inside these sublists. 
# This sentence representation is more suitable for downstream tasks like NER.

file_size = os.path.getsize(file_path)
token_seqs = []

with open(file_path, 'r', encoding='utf-8') as f, tqdm(total=file_size, unit='B', unit_scale=True, desc="Tokenizing") as pbar:
    while True:
        chunk = f.read(chunk_size)
        if not chunk:
            break
        texts = re.findall(pat, chunk)
        token_seqs.extend([list(token.encode('utf-8')) for token in texts])
        pbar.update(len(chunk))


def build_pair_stats(seqs):
    '''
    Stats to calculate the count of token-pairs and their indices in token_seqs. 
    Knowing token-pairs' indices will help in faster updates during merging.
    '''
    
    pair_counts = Counter()
    pair_to_indices = defaultdict(set)

    with tqdm(total=len(seqs), desc="Counting Pairs", unit="seqs") as pbar:
        for idx, seq in enumerate(seqs):
            for i in range(len(seq) - 1):
                pair = (seq[i], seq[i + 1])
                pair_counts[pair] += 1
                pair_to_indices[pair].add(idx)
            pbar.update(1)

    return pair_counts, pair_to_indices


def merge_sequence(seq, tok_1, tok_2, new_token):
    '''
    Function to merge tokens.
    For a given sequence, look for the token pair (tok_1, tok_2) that needs to be updated.
    Then update the token pair with the new token (new_token).
    Return the merged sequence.
    '''
    
    merged_seq = []
    i = 0
    while i < len(seq):
        if i < len(seq) - 1 and seq[i] == tok_1 and seq[i + 1] == tok_2:
            merged_seq.append(new_token)
            i += 2
        else:
            merged_seq.append(seq[i])
            i += 1
    return merged_seq

# Initial Pair Stats
pair_counts, pair_to_indices = build_pair_stats(token_seqs)

# BPE Merge Loop
# The merge loop will run n number of times, where n is the number of new tokens to be created.
# Since we need vocab of 997 and we already have 0-225 vocab values assigned, run this loop 997-255 times to generate 998 len vocab.
for _ in tqdm(range(vocab_size - orig_vocab_size), desc="Merging"):
    if not pair_counts:
        break

    # Find most frequent unmerged pair
    for (tok_1, tok_2), _ in pair_counts.most_common():
        if (tok_1, tok_2) not in merged:
            break

    # Create new token value and store it in the merged map
    new_token += 1
    merged[(tok_1, tok_2)] = new_token

    # For the token to be merged, get all affected indices
    affected_indices = pair_to_indices.get((tok_1, tok_2), set())
    if not affected_indices:
        continue

    # Merge only affected sequences
    for idx in list(affected_indices):
        old_seq = token_seqs[idx]
        new_seq = merge_sequence(old_seq, tok_1, tok_2, new_token)
        token_seqs[idx] = new_seq
    
        # Because we merged, decrement the pair counts from "old sequence" and delete their respective indices that were updated.
        for i in range(len(old_seq) - 1):
            pair = (old_seq[i], old_seq[i + 1])
            pair_counts[pair] -= 1
            if pair_counts[pair] <= 0:
                del pair_counts[pair]
            pair_to_indices[pair].discard(idx)
            if not pair_to_indices[pair]:
                del pair_to_indices[pair]

        # Because we merged, increment the pair counts from "new sequence" and add their respective indices that were updated.
        for i in range(len(new_seq) - 1):
            pair = (new_seq[i], new_seq[i + 1])
            pair_counts[pair] += 1
            pair_to_indices[pair].add(idx)


Tokenizing: 100%|███████████████████████████▉| 897M/900M [03:00<00:00, 4.96MB/s]
Counting Pairs: 100%|██████| 315181836/315181836 [04:57<00:00, 1060373.97seqs/s]
Merging: 100%|██████████████████████████████| 743/743 [2:02:49<00:00,  9.92s/it]


In [5]:
# Create a reverse map of the merged items.
rev_merged = {v:k for k, v in merged.items()}

In [6]:
def get_counts(list_of_tokens):
    '''
    Count the frequency of adjacent token pairs across a list of token sequences.

    Notes:
        This function is a simplified version of build_pair_stats, intended for sentence-level inference,
        where maximum efficiency is not critical.
    '''
    counts = {}

    for token_list in list_of_tokens:
        for pair in zip(token_list, token_list[1:]):
            if pair not in counts:
                counts[pair] = 1
            else:
                counts[pair] += 1
    
    count_vals = [(v, k[0], k[1]) for k,v in counts.items()]
    count_vals = sorted(count_vals, reverse=True)
    return count_vals

    
def encode(text):
    '''
    Encode a text sequence by iteratively merging token pairs based on a predefined merge map.

    Process:
        - Split the input text into fragments.
        - Tokenize each fragment into byte-encoded tokens.
        - Iteratively merge token pairs found in the 'merged' mapping, 
          prioritizing pairs with the smallest assigned merged token value.
        - Continue merging until no more applicable pairs are found.

    Notes:
        Priority is given to token pairs with the lowest merged value to preserve merge coherence.
    '''
    texts = re.findall(pat, text)
    list_of_tokens = [list(text.encode('utf-8')) for text in texts]

    seen = set()
    while True:
        count_vals = get_counts(list_of_tokens)
        pairs = [(val[1:]) for val in count_vals]
        pairs = [pair for pair in pairs if pair in merged and pair not in seen]
        valid_pairs = {pair:merged[pair] for pair in pairs}
        
        if valid_pairs:
            min_pair = min(valid_pairs, key=lambda x:merged[x])
            seen.add(min_pair)
            tok_1, tok_2 = min_pair
            new_token = merged[(tok_1, tok_2)]

            new_list_of_tokens = []
    
            for tokens in list_of_tokens:
                new_tokens = merge_sequence(tokens, tok_1, tok_2, new_token)
                new_list_of_tokens.append(new_tokens)
            list_of_tokens = new_list_of_tokens
        else:
            break

    return list_of_tokens

In [7]:
def num_to_str(number):
    '''
    Convert an integer (0–255) to its corresponding single UTF-8 character.
    '''
    if number > 255:
        raise ValueError("Invalid Token")
    byte_representation = number.to_bytes(1, byteorder='big')
    utf8_string = byte_representation.decode('utf-8')
    return utf8_string

In [8]:
def split_sequence(seq, tok_1, tok_2, merged_token):
    '''
    Split merged tokens in a sequence back into their original token pairs using the 'rev_merged' mapping.
    '''
    new_tokens = []

    for token in seq:
        if token == merged_token:
            tok_1, tok_2 = rev_merged[token]
            new_tokens.append(tok_1)
            new_tokens.append(tok_2)
        else:
            new_tokens.append(token)

    return new_tokens

In [9]:
def decode(list_of_tokens):
    '''
    Decode a list of merged token sequences back into the original text string.
    
    Process:
        - Iteratively split merged tokens into their original token pairs, 
          starting from the highest merged token value.
        - Flatten the final list of tokens.
        - Convert each token back to its corresponding UTF-8 character.
        - Join the characters into a single decoded string.
    
    Notes:
        Priority is given to token pairs with the highest merged value to preserve un-merge coherence.
    '''

    while True:
        valid_tokens = []
        for tokens in list_of_tokens:
            valid_tokens.extend([token for token in tokens if token in rev_merged.keys()])

        if valid_tokens:
            max_token = max(valid_tokens)                 
            new_list_of_tokens = []
            for tokens in list_of_tokens:
                new_tokens = split_sequence(tokens, tok_1, tok_2, max_token)
                new_list_of_tokens.append(new_tokens)
            list_of_tokens = new_list_of_tokens
        else:
            break
    final_tokens = []
    for tokens in list_of_tokens:
        final_tokens.extend(tokens)
    str_tokens = [num_to_str(token) for token in final_tokens]
    return "".join(str_tokens)

In [10]:
# Here we assert that an arbitary text, once passed through the encode and decode functions, returns the same input text.
inp = '''Through the 1980s, Tarantino had a number of jobs. After lying about his age, he worked as an usher at an adult movie theater in Torrance, called the Pussycat Theater. '''
out = decode(encode(inp))
inp==out

True

In [29]:
encoded = encode(inp)
flattened = []
for t in encoded: flattened.extend(t)

print("tokenized text example  : ", flattened)
print("Length of original text : ", len(inp))
print("Length of tokenized text: ", len(flattened))
print("Tokenized efficieny     : ", round(len(inp)/ len(flattened), 2))

tokenized text example  :  [465, 114, 445, 32, 260, 32, 630, 48, 115, 44, 32, 84, 271, 379, 257, 111, 32, 430, 32, 97, 32, 952, 32, 270, 32, 106, 553, 115, 46, 32, 65, 102, 298, 32, 301, 278, 32, 805, 32, 368, 32, 381, 44, 32, 284, 32, 647, 267, 32, 273, 32, 259, 32, 306, 567, 32, 263, 32, 259, 32, 292, 508, 32, 355, 948, 32, 260, 491, 32, 257, 32, 84, 265, 500, 321, 44, 32, 881, 32, 260, 32, 80, 306, 529, 99, 263, 32, 299, 491, 46, 32]
Length of original text :  168
Length of tokenized text:  91
Tokenized efficieny     :  1.85


In [16]:
# We are saving the merged-map as a json for future reference!
import json
merged_serializable = {f"({k[0]},{k[1]})": v for k, v in merged.items()}
with open("bpe_merged.json", "w", encoding="utf-8") as f:
    json.dump(merged_serializable, f, indent=2)