In [1]:
# Sample corpus (a short poem) that the later cells use to build the
# vocabulary and to run the byte-pair-encoding round trip.
text = """
You were There, When I need you the Most
You are the One, I Knew, when I Saw you First.

I still Remember, How we Met in the Mid of July.
Since then, You became My Hardest Goodbye.
Whenever I see you, I Fall for You, All over Again,
My life was a Desert, you Came in like a Soothing Rain.

You Always Showed me Path, Like a Flashlight,
You are the One, The Moon of This Dark Night.

They say, You came on Earth Alone, as a Half,
You Wander in Search of Other, for a Complete Laugh.
I knew, I found Mine, when I Get to know,
We need No Words to Convey, What we Feel though.

I was, I am, I'll wait in the Next Life too,
Cause My Heart Knows, that it's Other Half is You.
"""

In [2]:
# create a vocabulary dictionary
def create_vocab(text):
    """Build the base (pre-merge) vocabulary from the UTF-8 bytes of ``text``.

    Every distinct byte of ``text.encode('utf-8')`` becomes one token, stored
    as a length-1 ``bytes`` object. Ids are assigned in ascending byte order,
    so the same text always yields the same vocabulary.

    Args:
        text: The string whose bytes seed the vocabulary.

    Returns:
        A ``(char_to_idx, idx_to_char)`` tuple: ``char_to_idx`` maps each
        single-byte token to its integer id and ``idx_to_char`` is the
        inverse mapping. Both are empty for empty ``text``.
    """
    # Work on bytes rather than characters: encode() looks tokens up one byte
    # at a time, so a multi-byte character such as "é" must contribute each of
    # its bytes separately or the lookup raises KeyError.
    # sorted() keeps ids stable across kernel restarts; plain set iteration
    # order is not guaranteed to be reproducible.
    distinct_bytes = sorted(set(text.encode('utf-8')))

    char_to_idx = {}
    idx_to_char = {}

    for idx, byte_value in enumerate(distinct_bytes):
        token = bytes([byte_value])

        char_to_idx[token] = idx
        idx_to_char[idx] = token

    return char_to_idx, idx_to_char

In [3]:
def encode(text):
    """Translate ``text`` into a list of vocabulary ids, one per UTF-8 byte.

    Each byte of the UTF-8 encoding is wrapped in a length-1 ``bytes`` object
    and looked up in the module-level ``char_to_idx`` mapping; a byte that is
    not a key of that mapping raises ``KeyError``.
    """
    token_ids = []

    # iterating a bytes object yields ints, so rebuild the one-byte key
    for byte_value in text.encode('utf-8'):
        token_ids.append(char_to_idx[bytes([byte_value])])

    return token_ids

In [4]:
def decode(unicode_code_point_list):
    """Turn a sequence of vocabulary ids back into text.

    Every id is mapped to its ``bytes`` value through the module-level
    ``idx_to_char`` table; the pieces are concatenated in order and the
    resulting byte string is decoded as UTF-8.
    """
    pieces = (idx_to_char[token_id] for token_id in unicode_code_point_list)

    return b''.join(pieces).decode('utf-8')

In [5]:
# Baseline before any merging: build the vocabulary from `text`, encode the
# text with it, then report how many tokens the text takes and how many
# distinct token ids it actually uses.
char_to_idx, idx_to_char = create_vocab(text)
encoded_unicode_chars = encode(text)
print(f"total tokens in raw text: {len(encoded_unicode_chars)}")
print(f"vocab size would be: {len(set(encoded_unicode_chars))}")

total tokens in raw text: 671
vocab size would be: 48


```
Assumptions:
    1. LLM will have context size of 50.

Initial observation:
    1. The raw text has 671 tokens with 48 chars in the vocabulary.
    2. With current context size of 50, we can only process 50 tokens at a time.
    3. To process the entire text using an LLM with a context window of 50 tokens, we would need 14 segments (windows). (671/50 = 13.42, rounded up to 14).

Goal:
    1. We need an algorithm that reduces the total number of tokens by merging frequently occurring character or symbol pairs, so we can ideally process the entire text in fewer windows, or even in one window.
    2. Making sure that the vocabulary size is not too large, as the vocab size determines the memory required to store token embeddings and increases model complexity. A large vocabulary can lead to sparsity and make the model harder to train.

Note: 
    1. There is a trade-off between the number of tokens and the vocabulary size: a larger vocabulary lets the same text be represented with fewer tokens, but it also increases the model complexity and memory requirements.
    2. A larger vocabulary also means more sparsity: in the final step of a transformer we have to predict the next token out of our vocabulary, i.e. the probability distribution of the next token is spread over the entire vocabulary.
    3. So, we need to find a balance between the number of tokens and the vocabulary size.
    4. The goal is to reduce the number of tokens while keeping the vocabulary size manageable.
```

In [None]:
# Supporting functions for bpe
def get_stats(unicode_code_point_list: list) -> dict:
    """
        Count how often each pair of adjacent token ids occurs.

        The result maps ``(left, right)`` tuples to their number of
        occurrences and is ordered from the most to the least frequent pair;
        pairs with equal counts keep the order in which they first appeared.
        Inputs with fewer than two items give an empty dict.

        Sample usage:
        >>> get_stats([1, 2, 3, 4, 1, 2, 4, 5, 6, 3, 4, 2, 3, 4])
        {
            (3, 4): 3,
            (1, 2): 2,
            (2, 3): 2,
            (4, 1): 1,
            (2, 4): 1,
            (4, 5): 1,
            (5, 6): 1,
            (6, 3): 1,
            (4, 2): 1
            }

    """
    pair_counts = {}

    # walk every position that still has a right-hand neighbour
    for pos in range(len(unicode_code_point_list) - 1):
        key = (unicode_code_point_list[pos], unicode_code_point_list[pos + 1])

        if key in pair_counts:
            pair_counts[key] += 1
        else:
            pair_counts[key] = 1

    # stable sort, highest count first; dicts keep insertion order
    ranked = list(pair_counts.items())
    ranked.sort(key=lambda entry: entry[1], reverse=True)

    return dict(ranked)

In [None]:
# text = "Hey there, how are you doing today? I hope you're having a great day! Let's make the most of it together."
# Rebuild the base vocabulary from `text` before merging: merge() adds new
# entries to these two dictionaries, so re-running this cell without the
# reset would start from an already-extended vocabulary.
char_to_idx, idx_to_char = create_vocab(text)

def merge(unicode_code_point_list: list) -> tuple:
    """Perform one byte-pair-encoding merge step on a list of token ids.

    The most frequent adjacent pair is looked up with ``get_stats``. If it
    occurs more than once, a new token whose bytes are the concatenation of
    the pair's two tokens is registered in the module-level ``char_to_idx``
    and ``idx_to_char`` dictionaries, and every non-overlapping occurrence of
    the pair (scanning left to right) is replaced by the new token id.

    Args:
        unicode_code_point_list: Token ids to process. The list is updated in
            place when a merge happens.

    Returns:
        A ``(unicode_code_point_list, merged)`` tuple: the same list object
        that was passed in, and ``True`` if a pair was merged or ``False`` if
        nothing was left to merge (fewer than two tokens, or no pair occurring
        more than once).
    """
    # get the dictionary of all the pairs and their occurrences
    stats = get_stats(unicode_code_point_list)

    # pair with max occurrences; default=None covers inputs with no pairs at
    # all, so every exit path can return the same (list, bool) shape that
    # callers unpack
    pair_with_max_occurrences = max(stats, key=stats.get, default=None)

    if pair_with_max_occurrences is None or stats[pair_with_max_occurrences] <= 1:
        return unicode_code_point_list, False

    # mint a new token for the pair & new idx
    left_idx, right_idx = pair_with_max_occurrences
    new_token = idx_to_char[left_idx] + idx_to_char[right_idx]
    new_token_idx = len(idx_to_char)

    # add the new token to the char_to_idx and idx_to_char dictionaries
    char_to_idx[new_token] = new_token_idx
    idx_to_char[new_token_idx] = new_token

    # Build the merged sequence in a single pass instead of pop()-ing from
    # the middle of the list, which shifts the tail on every replacement.
    merged_ids = []
    total = len(unicode_code_point_list)
    i = 0
    while i < total:
        is_pair = (
            i + 1 < total
            and (unicode_code_point_list[i], unicode_code_point_list[i + 1]) == pair_with_max_occurrences
        )
        if is_pair:
            merged_ids.append(new_token_idx)
            i += 2
        else:
            merged_ids.append(unicode_code_point_list[i])
            i += 1

    # slice assignment keeps the in-place update of the caller's list
    unicode_code_point_list[:] = merged_ids

    return unicode_code_point_list, True

def bpe(text):
    """Encode ``text`` and merge frequent pairs until none can be merged.

    The text is first encoded with ``encode``; ``merge`` is then applied
    repeatedly until it reports that no pair was merged. New tokens minted
    along the way are added to the module-level vocabulary by ``merge``.
    A short summary (vocabulary sizes, token counts, and the size of the
    result as a percentage of the raw token count) is printed.

    Args:
        text: The string to tokenize.

    Returns:
        The list of token ids after all merges.
    """
    starting_vocab_size = len(idx_to_char)

    # encode text to get the initial list of token ids
    unicode_code_point_list = encode(text)
    raw_text_token_count = len(unicode_code_point_list)

    # A pair needs two tokens, so merge() is only called while at least one
    # pair exists; stop as soon as it reports that nothing was merged.
    while len(unicode_code_point_list) > 1:
        unicode_code_point_list, merged = merge(unicode_code_point_list)
        if not merged:
            break

    processed_text_token_count = len(unicode_code_point_list)
    if raw_text_token_count:
        compression_ratio = (processed_text_token_count / raw_text_token_count) * 100
    else:
        # empty text: nothing was compressed, and dividing by zero must be avoided
        compression_ratio = 100.0
    final_vocab_size = len(idx_to_char)

    print(f"Starting vocab size: {starting_vocab_size}")
    print(f"Final vocab size: {final_vocab_size}")
    print(f"total tokens in raw text: {raw_text_token_count}")
    print(f"total tokens after bpe: {processed_text_token_count}")
    print(f"compression ratio: {compression_ratio}")

    return unicode_code_point_list

# Round-trip check: decoding the merged token ids should reproduce `text`.
decode(bpe(text)) == text

Starting vocab size: 48
Final vocab size: 135
total tokens in raw text: 671
total tokens after bpe: 312
compression ratio: 46.497764530551414


True