In [1]:
# LLM FROM SCRATCH

In [2]:
# Building a Byte Pair Encoding (BPE) Tokenizer from scratch

In [3]:
# Step 1: Prepare training data
# Note: here we are NOT training a neural network.
# We are just creating a list of tokens from text.
#
# What we need: a text corpus.
# The tokenizer will learn "merge rules" based on how often pairs of characters appear.
#
# Example:
#   "i"   -> token 1
#   "s"   -> token 2
#   "is"  -> token 3
#
# How it works:
# 1. Start with text split into individual characters.
#    (every character is always in the vocabulary)
# 2. Count how often character pairs occur.
# 3. Merge the most frequent pairs into new tokens (subwords).
# 4. Over time, the tokenizer builds a vocabulary that mixes
#    single characters + useful subword tokens.
#
# Important:
# - Characters always remain as fallback tokens.
# - Subwords get priority when tokenizing, making it more efficient.
#
# Concretely:
# - We scan the text and count how many times each pair of characters appears.
# - For example, if the pair "is" appears very often, we create a new token for it.
# - This reduces computation: instead of processing "i" and "s" separately,
#   we can treat "is" as a single token.
# - Note: "i" and "s" still remain in the vocabulary as individual tokens,
#   but whenever the pair "is" exists, it takes priority over the single characters.
#
# Iterative merges:
# - The process continues on top of previously created tokens.
# - For example, if "is" was already merged into a token, and we notice "his"
#   appears frequently, we merge "h" + "is" → "his".
# - Next, if "this" is common, we merge "t" + "his" → "this".
# - This way, the vocabulary gradually grows from characters → subwords → whole words,
#   depending on frequency in the training text.

In [4]:
# Toy training corpus: four short sentences with heavily overlapping words,
# so that frequent character pairs (and therefore useful merges) show up
# after only a handful of BPE iterations.
corpus = [
    "This is the first document.",
    "This document is the second document.",
    "And this is the third one.",
    "Is this the first document?",
]

In [5]:
# Show the header followed by one document per line.
print("Corpus: ", *corpus, sep="\n")

Corpus: 
This is the first document.
This document is the second document.
And this is the third one.
Is this the first document?


In [6]:
# Step 2: Initialize vocabulary with unique characters
#
# The first version of our vocabulary is simply all the unique characters
# that appear in the training corpus.
# Each character will be treated as an initial token.
#
# In addition, we add a special end-of-word marker (</w>).
# This marker helps the tokenizer know where words end, so that
# frequent whole words or subwords can be merged properly later.
#
# Example:
#   "this"  →  ["t", "h", "i", "s", "</w>"]
#   "is"    →  ["i", "s", "</w>"]
#
# Every distinct character that occurs anywhere in the corpus.
# (Each document is a string, i.e. an iterable of characters.)
unique_chars = set().union(*corpus)

# Sorting makes the vocabulary order deterministic across runs.
vocab = sorted(unique_chars)

# The end-of-word marker is appended last, after the sorted characters.
end_of_word = "</w>"
vocab.append(end_of_word)

In [7]:
# Header, the vocabulary itself, and its size, each on its own line.
print("Initial Vocabulary:", vocab, f"Vocabulary Size: {len(vocab)}", sep="\n")

Initial Vocabulary:
[' ', '.', '?', 'A', 'I', 'T', 'c', 'd', 'e', 'f', 'h', 'i', 'm', 'n', 'o', 'r', 's', 't', 'u', '</w>']
Vocabulary Size: 20


In [8]:
# Step 3: Pre-tokenize the corpus
#
# Goal:
# - Split text into words (by spaces, for simplicity).
# - Break each word into its characters.
# - Add the special end-of-word token (</w>) at the end of every word.
#
# Why?
# - This gives us the initial representation of words as sequences of characters.
# - Example: "This" → ("T", "h", "i", "s", "</w>")
#
# Implementation details:
# - We store each word as a tuple of characters (immutable).
#   Tuples can be used as dictionary keys, unlike lists.
# - We count how many times each word (as a sequence of characters) appears
#   in the whole corpus.
#
# Note:
# - Adding the </w> token ensures that subwords are learned within word
#   boundaries. For example:
#   "document" → ("d", "o", "c", "u", "m", "e", "n", "t", "</w>")
#   This way, if "doc" becomes a frequent subword, it is clear that it
#   belongs inside the word "document" and not across words.
#
word_splits = {}
for doc in corpus:
    # Splitting on a single space can produce empty strings (e.g. from
    # consecutive spaces); filter(None, ...) drops them.
    for word in filter(None, doc.split(' ')):
        # Key: the word's characters followed by the end-of-word marker,
        # stored as a tuple so it is hashable.
        word_tuple = (*word, end_of_word)
        # Count how many times this exact word form occurs in the corpus.
        word_splits[word_tuple] = word_splits.get(word_tuple, 0) + 1

print("\nPre-tokenized Word Frequencies:")
print(word_splits)


Pre-tokenized Word Frequencies:
{('T', 'h', 'i', 's', '</w>'): 2, ('i', 's', '</w>'): 3, ('t', 'h', 'e', '</w>'): 4, ('f', 'i', 'r', 's', 't', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '.', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '</w>'): 1, ('s', 'e', 'c', 'o', 'n', 'd', '</w>'): 1, ('A', 'n', 'd', '</w>'): 1, ('t', 'h', 'i', 's', '</w>'): 2, ('t', 'h', 'i', 'r', 'd', '</w>'): 1, ('o', 'n', 'e', '.', '</w>'): 1, ('I', 's', '</w>'): 1, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '?', '</w>'): 1}


In [9]:
# Step 4: Count symbol pair frequencies
#
# Goal:
# - Take the dictionary of word splits we created (word_splits).
# - For each word, look at all adjacent pairs of symbols.
# - Count how many times each pair appears across the entire corpus.
#
# Example:
#   Input: {("T", "h", "i", "s", "</w>"): 2}
#   Output: {("T", "h"): 2, ("h", "i"): 2, ("i", "s"): 2, ("s", "</w>"): 2}

import collections 

def get_pair_stats(splits):
    """
    Tally how often each adjacent symbol pair occurs across all words.

    Args:
        splits: dict mapping a word (a tuple of symbols) to its frequency,
            e.g. {('T', 'h', 'i', 's', '</w>'): 2, ('i', 's', '</w>'): 3}.

    Returns:
        A collections.defaultdict(int) mapping each adjacent pair to its
        frequency-weighted count. For the example above:
        {('T', 'h'): 2, ('h', 'i'): 2, ('i', 's'): 5, ('s', '</w>'): 5}.
    """
    # defaultdict(int) creates a missing key with value 0 on first access,
    # so "+=" below works without checking whether the pair was seen before.
    pair_counts = collections.defaultdict(int)

    for word_tuple, freq in splits.items():
        symbols = list(word_tuple)
        # Pair every symbol with its right-hand neighbour:
        # ['T', 'h', 'i', 's', '</w>'] -> ('T','h'), ('h','i'), ('i','s'), ('s','</w>').
        # zip stops at the shorter input, so a word with fewer than two
        # symbols contributes no pairs.
        for left, right in zip(symbols, symbols[1:]):
            # The pair occurs once per occurrence of the word, i.e. `freq` times.
            pair_counts[(left, right)] += freq
    return pair_counts


In [11]:
# Step 5: Merge the most frequent pair
#
# Goal:
# - Take the most frequent pair of symbols (from pair_counts).
# - Go through every word in word_splits.
# - Replace occurrences of that pair with a new merged token.
# - Keep track of merges (so we could undo or review them later if needed).
#
# Example:
#   Input pair_to_merge: ('i', 's')
#   Input word_splits: {
#       ('T', 'h', 'i', 's', '</w>'): 2,
#       ('i', 's', '</w>'): 3,
#       ...
#   }
#   Output new_splits: {
#       ('T', 'h', 'is', '</w>'): 2,
#       ('is', '</w>'): 3,
#       ...
#   }
#
def merge_pair(pair_to_merge, splits):
    """
    Replace every occurrence of an adjacent symbol pair with one merged token.

    Args:
        pair_to_merge: 2-tuple of symbols to fuse, e.g. ('i', 's').
        splits: dict mapping word tuples to their frequency,
            e.g. {('T', 'h', 'i', 's', '</w>'): 2, ('i', 's', '</w>'): 3}.

    Returns:
        A new dict with the same shape as `splits` in which each adjacent
        (first, second) occurrence is replaced by the string first + second,
        e.g. {('T', 'h', 'is', '</w>'): 2, ('is', '</w>'): 3}.
        `splits` itself is not modified. If two input words become the same
        symbol sequence after the merge, their frequencies are summed.
    """
    first, second = pair_to_merge    # ('i', 's') -> first = 'i', second = 's'
    merged_token = first + second    # 'is' (a plain string)

    new_splits = {}
    for word_tuple, freq in splits.items():
        symbols = list(word_tuple)
        new_symbols = []
        i = 0
        while i < len(symbols):
            # The bounds check guarantees symbols[i + 1] exists before comparing.
            if i < len(symbols) - 1 and symbols[i] == first and symbols[i + 1] == second:
                new_symbols.append(merged_token)
                # Skip both halves of the pair. Scanning left to right means
                # overlapping matches are resolved greedily:
                # ('a', 'a', 'a') with pair ('a', 'a') -> ('aa', 'a').
                i += 2
            else:
                # No match at this position: keep the symbol unchanged.
                new_symbols.append(symbols[i])
                i += 1

        # Accumulate rather than assign: distinct input keys can collapse to
        # the same tuple (e.g. ('a', 'b') and ('ab',)), and plain assignment
        # would silently drop the earlier word's count.
        new_key = tuple(new_symbols)
        new_splits[new_key] = new_splits.get(new_key, 0) + freq
    return new_splits

In [12]:
# Step 6: Quick recap of our data structures (with examples from our corpus)
#
# Corpus (toy example):
# [
#   "This is the first document.",
#   "This document is the second document.",
#   "And this is the third one.",
#   "Is this the first document?",
# ]
#
# 1) vocab
#    - Definition: list of ALL unique characters found in the corpus, plus the end-of-word marker </w>.
#    - Example (one possible order, after sorting and then appending </w>):
#      [' ', '.', '?', 'A', 'I', 'T', 'c', 'd', 'e', 'f', 'h', 'i', 'm', 'n', 'o', 'r', 's', 't', 'u', '</w>']
#    - Purpose: this is our initial token set (characters).
#
# 2) word_splits
#    - Definition: a dictionary mapping each WORD (as a tuple of characters + '</w>') to its frequency in the corpus.
#      e.g., { ('T','h','i','s','</w>'): 2, ('i','s','</w>'): 3, ... }
#    - How it’s built: split each line by spaces → take each word → turn into list(word) + ['</w>'] → tuple(...) as key.
#    - Notes on punctuation and case:
#        * Punctuation stays inside the word (e.g., "document." becomes ('d','o','c','u','m','e','n','t','.', '</w>')).
#        * Case is preserved (e.g., "This" vs "this" are different keys).
#    - Concrete examples from our corpus:
#        * ('T','h','i','s','</w>') appears 2 times  → "This" (line 1 and 2)
#        * ('t','h','i','s','</w>') appears 2 times  → "this" (line 3 and 4)
#        * ('i','s','</w>') appears 3 times         → "is"   (lines 1, 2, 3)
#
# 3) get_pair_stats(splits)
#    - Input: the word_splits dictionary.
#    - Output: a dictionary that counts ADJACENT symbol pairs INSIDE each word (never across word boundaries).
#      Example output shape: { ('T','h'): 2, ('h','i'): 5, ('i','s'): 7, ('s','</w>'): 8, ... }
#    - How it works:
#        * For each word tuple like ('T','h','i','s','</w>') with freq=2,
#          it adds counts for ('T','h'), ('h','i'), ('i','s'), ('s','</w>') each +2.
#        * It never considers pairs across words (e.g., it does NOT pair the last char of one word with the first of the next).
#    - Concrete intuition with our corpus:
#        * ('i','s') appears inside:
#            - "This" (upper T) → 2 times total
#            - "this" (lower t) → 2 times total
#            - "is" as a word   → 3 times total (pairs: ('i','s') and ('s','</w>'))
#          So ('i','s') can easily sum to 2 + 2 + 3 = 7 in the pair counts.
#
# 4) merge_pair(pair_to_merge, splits)
#    - Input:
#        * pair_to_merge: a tuple like ('i','s') — typically the MOST frequent pair from get_pair_stats.
#        * splits: the current word_splits dictionary.
#    - Process:
#        * For every word tuple, find every occurrence of the adjacent pair ('i','s') and replace it with the merged token "is".
#        * Return a NEW dictionary (same shape as word_splits) but with updated word tuples that include the merged token.
#    - Output shape (same as word_splits):
#        * Before:
#            {
#              ('T','h','i','s','</w>'): 2,
#              ('i','s','</w>'): 3,
#              ...
#            }
#        * After merging ('i','s') → "is":
#            {
#              ('T','h','is','</w>'): 2,
#              ('is','</w>'): 3,
#              ...
#            }

In [13]:
# Step 7: Iterative BPE merging loop
#
# Goal:
# - Repeatedly find the most frequent adjacent symbol pair in the corpus (within words),
#   merge that pair into a new token, and update our splits/vocabulary/merge rules.
# - Do this up to `num_merges` times, or stop early if no more pairs exist.
#
# Key variables (with concrete examples from our toy corpus):
# - num_merges: the maximum number of merge steps to perform (e.g., 15).
# - merges: a dictionary storing learned merge rules. Example after a few steps:
#       { ('i','s'): 'is', ('t','h'): 'th' }
#   This means we learned to merge 'i' + 's' → "is", and 't' + 'h' → "th".
# - current_splits: the working copy of word_splits that we keep updating after each merge.
#   Example (before any merges):
#       {
#         ('T','h','i','s','</w>'): 2,
#         ('i','s','</w>'): 3,
#         ('t','h','e','</w>'): 4,
#         ('d','o','c','u','m','e','n','t','.','</w>'): 2,
#         ...
#       }
#
# Functions used:
# - get_pair_stats(splits) → dict:
#       Input: current_splits (dict of word tuples → frequency).
#       Output: dictionary mapping adjacent symbol pairs to total frequency in the corpus.
#       Example:
#           { ('i','s'): 7, ('s','</w>'): 8, ('t','h'): 7, ('h','i'): 5, ('T','h'): 2, ... }
# - merge_pair(pair_to_merge, splits) → dict:
#       Input: a pair like ('i','s') and current_splits.
#       Process: replaces every occurrence of that pair inside each word tuple with the merged token ("is").
#       Output: a new splits dict with updated word tuples.
#       Example (effect on entries):
#           Before: ('i','s','</w>') → After: ('is','</w>')
#           Before: ('T','h','i','s','</w>') → After: ('T','h','is','</w>')
#
# Maximum number of merge steps; the loop stops earlier if no pairs remain.
num_merges = 15

# Learned merge rules in the order they were learned: (left, right) -> merged token,
# e.g. {('s', '</w>'): 's</w>', ('i', 's</w>'): 'is</w>'}.
merges = {}

# Work on a copy so the original word_splits stays available for inspection
# and this cell restarts from the same state every time it is run.
current_splits = word_splits.copy()

print("\n--- Starting BPE Merges ---")
print(f"Initial Splits: {current_splits}")
print("-" * 30)

for i in range(num_merges):
    print(f"\nMerge Iteration {i+1}/{num_merges}")

    # 1) Count adjacent pairs across all words in current_splits.
    #    Example shape:
    #      { ('T','h'): 2, ('h','i'): 5, ('i','s'): 7, ('s','</w>'): 8,
    #        ('t','h'): 7, ('h','e'): 4, ('e','</w>'): 4, ('f','i'): 2, ... }
    pair_stats = get_pair_stats(current_splits)

    # No adjacent pairs left (every word is already a single token): stop early.
    if not pair_stats:
        print("No more pairs to merge.")
        break

    # 2) Show the top pairs (for understanding only; not used for selection).
    #    Sorting by item[1] (the frequency) with reverse=True puts the most
    #    frequent pair first. Example:
    #      [(('s','</w>'), 8), (('i','s'), 7), (('t','h'), 7), (('h','i'), 5), (('h','e'), 4)]
    sorted_pairs = sorted(pair_stats.items(), key=lambda item: item[1], reverse=True)
    print(f"Top 5 Pair Frequencies: {sorted_pairs[:5]}")

    # 3) Select the single most frequent pair.
    #    max iterates over the keys of pair_stats and compares them by their
    #    frequency; on ties the pair that was inserted first wins.
    best_pair = max(pair_stats, key=pair_stats.get)
    best_freq = pair_stats[best_pair]
    print(f"Found Best Pair: {best_pair} with Frequency: {best_freq}")

    # 4) Merge that pair in every word where it appears.
    #    Example with best_pair = ('i','s'):
    #      ('i','s','</w>')         -> ('is','</w>')
    #      ('T','h','i','s','</w>') -> ('T','h','is','</w>')
    current_splits = merge_pair(best_pair, current_splits)
    new_token = best_pair[0] + best_pair[1]
    print(f"Merging {best_pair} into '{new_token}'")
    print(f"Splits after merge: {current_splits}")

    # 5) Add the new token to the vocabulary.
    #    vocab lives outside this cell and is not reset here, so guard the
    #    append: re-running the cell (or a merge that happens to produce a
    #    string that is already a token) must not create duplicate entries.
    if new_token not in vocab:
        vocab.append(new_token)
    print(f"Updated Vocabulary: {vocab}")

    # 6) Record the merge rule so tokenization can be reproduced on new text.
    #    Example: merges[('i','s')] = 'is'
    merges[best_pair] = new_token
    print(f"Updated Merges: {merges}")
    print("-" * 30)



--- Starting BPE Merges ---
Initial Splits: {('T', 'h', 'i', 's', '</w>'): 2, ('i', 's', '</w>'): 3, ('t', 'h', 'e', '</w>'): 4, ('f', 'i', 'r', 's', 't', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '.', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '</w>'): 1, ('s', 'e', 'c', 'o', 'n', 'd', '</w>'): 1, ('A', 'n', 'd', '</w>'): 1, ('t', 'h', 'i', 's', '</w>'): 2, ('t', 'h', 'i', 'r', 'd', '</w>'): 1, ('o', 'n', 'e', '.', '</w>'): 1, ('I', 's', '</w>'): 1, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '?', '</w>'): 1}
------------------------------

Merge Iteration 1/15
Top 5 Pair Frequencies: [(('s', '</w>'), 8), (('i', 's'), 7), (('t', 'h'), 7), (('h', 'i'), 5), (('h', 'e'), 4)]
Found Best Pair: ('s', '</w>') with Frequency: 8
Merging ('s', '</w>') into 's</w>'
Splits after merge: {('T', 'h', 'i', 's</w>'): 2, ('i', 's</w>'): 3, ('t', 'h', 'e', '</w>'): 4, ('f', 'i', 'r', 's', 't', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '.', '</w>'): 2, ('d', 'o', 'c', 'u', '

In [14]:
# Step 8: Review Final Results
#
#
# Example of what to expect (simplified):
# - merges:
#     { ('i','s'):'is', ('t','h'):'th', ('Th','is'):'This', ... }
# - current_splits:
#     { ('This','</w>'):2, ('is','</w>'):3, ('doc','ument','.</w>'):3, ... }
# - vocab:
#     [' ', '.', '?', 'A', 'I', 'T', ..., 'is', 'th', 'This', 'doc', ...]
#

# Banner marking the end of training.
print("\n--- BPE Merges Complete ---")

# Number of entries currently held in the vocabulary list.
print(f"Final Vocabulary Size: {len(vocab)}")

# 1) Merge rules in the order they were learned, e.g. ('s', '</w>') -> 's</w>'.
print("\nLearned Merges (Pair -> New Token):")
for merged_pair, merged_token in merges.items():
    print(f"{merged_pair} -> '{merged_token}'")

# 2) How each corpus word is segmented once all merges have been applied,
#    e.g. ('is</w>',): 3, ('the</w>',): 4, ...
print("\nFinal Word Splits after all merges:")
print(current_splits)

# 3) Vocabulary with any duplicates removed, sorted for a stable view.
print("\nFinal Vocabulary (sorted):")
final_vocab_sorted = sorted(set(vocab))
print(final_vocab_sorted)



--- BPE Merges Complete ---
Final Vocabulary Size: 35

Learned Merges (Pair -> New Token):
('s', '</w>') -> 's</w>'
('i', 's</w>') -> 'is</w>'
('t', 'h') -> 'th'
('th', 'e') -> 'the'
('the', '</w>') -> 'the</w>'
('d', 'o') -> 'do'
('do', 'c') -> 'doc'
('doc', 'u') -> 'docu'
('docu', 'm') -> 'docum'
('docum', 'e') -> 'docume'
('docume', 'n') -> 'documen'
('documen', 't') -> 'document'
('i', 'r') -> 'ir'
('.', '</w>') -> '.</w>'
('d', '</w>') -> 'd</w>'

Final Word Splits after all merges:
{('T', 'h', 'is</w>'): 2, ('is</w>',): 3, ('the</w>',): 4, ('f', 'ir', 's', 't', '</w>'): 2, ('document', '.</w>'): 2, ('document', '</w>'): 1, ('s', 'e', 'c', 'o', 'n', 'd</w>'): 1, ('A', 'n', 'd</w>'): 1, ('th', 'is</w>'): 2, ('th', 'ir', 'd</w>'): 1, ('o', 'n', 'e', '.</w>'): 1, ('I', 's</w>'): 1, ('document', '?', '</w>'): 1}

Final Vocabulary (sorted):
[' ', '.', '.</w>', '</w>', '?', 'A', 'I', 'T', 'c', 'd', 'd</w>', 'do', 'doc', 'docu', 'docum', 'docume', 'documen', 'document', 'e', 'f', 'h', '

In [15]:
# Note on real-world tokenizers:
#
# In practice, tokenizers do not work directly with characters or subword strings.
# Instead, each element of the vocabulary is mapped to a unique integer index.
#
# Why?
# - Numbers are much faster for a computer to process than strings.
# - Looking up an index in a table is more efficient than comparing text symbols.
# - Integers take less memory to store than characters or strings.
#
# Example of a vocabulary-to-index mapping:
# {
#     "!": 0,
#     "\"": 1,
#     "#": 2,
#     "$": 3,
#     "%": 4,
#     "&": 5,
#     "'": 6,
#     "(": 7,
#     ")": 8,
#     "*": 9,
#     "+": 10,
#     ...
# }
#
# With this setup:
# - Each token (character, subword, or word) is represented by an integer ID.
# - Any input text can be quickly converted to a sequence of numbers.
# - This sequence of numbers is what is actually fed into the embedding layer of an LLM.
