## Creating a Corpus for training

In [15]:
# Toy training corpus: four short sentences with heavily overlapping words,
# so that frequent character pairs emerge after only a few merges.
corpus = [
    "This is the first document.",
    "This document is the second document.",
    "And this is the third one.",
    "Is this the first document?",
]

# Show the documents, one per line, under a header.
print("Training Corpus:")
print(*corpus, sep="\n")

Training Corpus:
This is the first document.
This document is the second document.
And this is the third one.
Is this the first document?


## Creating a Vocabulary

In [16]:
# Every distinct character that occurs anywhere in the corpus is a base symbol.
unique_chars = {char for doc in corpus for char in doc}

# Marker appended to each word so merges can tell word-final symbols apart.
end_of_word = "</w>"

# Base vocabulary: characters in sorted order, followed by the marker.
vocab = sorted(unique_chars) + [end_of_word]

print("Initial Vocabulary:")
print(vocab)
print(f"Vocabulary Size: {len(vocab)}")

Initial Vocabulary:
[' ', '.', '?', 'A', 'I', 'T', 'c', 'd', 'e', 'f', 'h', 'i', 'm', 'n', 'o', 'r', 's', 't', 'u', '</w>']
Vocabulary Size: 20


## Creating Pre-tokenized words with their respective frequencies

In [17]:
# Map each pre-tokenized word -- a tuple of its characters followed by the
# end-of-word marker -- to the number of times it occurs in the corpus.
word_splits = {}
for doc in corpus:
  # split() with no argument breaks on any run of whitespace and never yields
  # empty strings. This is the same rule the encoder uses (doc.strip().split()),
  # so training and encoding agree on what a "word" is even when a document
  # contains tabs, newlines or repeated spaces.
  for word in doc.split():
    word_tuple = tuple(word) + (end_of_word,)
    word_splits[word_tuple] = word_splits.get(word_tuple, 0) + 1

print("\nPre-tokenized Word Frequencies:")
print(word_splits)


Pre-tokenized Word Frequencies:
{('T', 'h', 'i', 's', '</w>'): 2, ('i', 's', '</w>'): 3, ('t', 'h', 'e', '</w>'): 4, ('f', 'i', 'r', 's', 't', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '.', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '</w>'): 1, ('s', 'e', 'c', 'o', 'n', 'd', '</w>'): 1, ('A', 'n', 'd', '</w>'): 1, ('t', 'h', 'i', 's', '</w>'): 2, ('t', 'h', 'i', 'r', 'd', '</w>'): 1, ('o', 'n', 'e', '.', '</w>'): 1, ('I', 's', '</w>'): 1, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '?', '</w>'): 1}


## Helper function to count how often each adjacent symbol pair occurs

In [18]:
import collections

def get_pair_stats(splits):
  """Count adjacent symbol pairs across all words, weighted by word frequency.

  Args:
    splits: mapping from a word's symbol sequence (tuple) to its frequency.

  Returns:
    A ``collections.defaultdict(int)`` mapping ``(left, right)`` symbol pairs
    to the summed frequency of the words they occur in (an occurrence that
    repeats inside one word is counted once per repetition). Pairs appear in
    first-seen order.
  """
  counts = collections.defaultdict(int)
  for word, weight in splits.items():
    # Pair each symbol with its right-hand neighbour.
    for adjacent in zip(word, word[1:]):
      counts[adjacent] += weight
  return counts

## Helper function to merge a chosen pair (the most frequent one) in every word

In [9]:
def merge_pair(pair_to_merge, splits):
  """Replace every occurrence of a symbol pair with its concatenation.

  Args:
    pair_to_merge: ``(first, second)`` symbols to fuse into ``first + second``.
    splits: mapping from a word's symbol sequence (tuple) to its frequency.

  Returns:
    A new dict with the same words re-split after the merge. ``splits`` itself
    is not modified. If two input sequences become identical after merging,
    their frequencies are added together instead of one overwriting the other.
  """
  first, second = pair_to_merge
  merged_token = first + second

  new_splits = {}
  for word_tuple, freq in splits.items():
    new_symbols = []
    last = len(word_tuple) - 1

    i = 0
    while i < len(word_tuple):
      # Scan left to right; a match consumes two symbols, so overlapping
      # occurrences (e.g. 'a','a','a' with pair ('a','a')) merge only once.
      if i < last and word_tuple[i] == first and word_tuple[i+1] == second:
        new_symbols.append(merged_token)
        i += 2
      else:
        new_symbols.append(word_tuple[i])
        i += 1

    key = tuple(new_symbols)
    # Accumulate rather than assign: distinct inputs may collapse to one key.
    new_splits[key] = new_splits.get(key, 0) + freq
  return new_splits

## Training the tokenizer for 15 merge iterations

In [19]:
# Number of merge rules to learn; each iteration adds one token to the vocabulary.
num_merges = 15
# (left, right) -> merged token, stored in the order the merges are learned.
merges = {}

# Work on a copy so the original word frequencies survive a re-run of this cell.
current_splits = word_splits.copy()

print("\n--- Starting BPE Merges ---")
print(f"Initial Splits: {current_splits}")
print("-" * 30)

for i in range(num_merges):
  print(f"\nMerge Iteration {i+1}/{num_merges}")

  pair_stats = get_pair_stats(current_splits)
  if not pair_stats:
    # No word has two adjacent symbols left, so nothing can be merged.
    print("No more pairs to merge")
    break

  # Sorted view is for display only; the pair to merge is chosen with max() below.
  sorted_pairs = sorted(pair_stats.items(), key = lambda item: item[1], reverse=True)
  print(f"Top 5 Pair Frequencies: {sorted_pairs[:5]}")


  # Ties are broken in favour of the pair that was seen first.
  best_pair = max(pair_stats, key = pair_stats.get)
  best_freq = pair_stats[best_pair]
  print(f'Found Best Pair: {best_pair} with Frequency: {best_freq}')


  current_splits = merge_pair(best_pair, current_splits)
  new_token = best_pair[0] + best_pair[1]
  print(f"Merging {best_pair} into '{new_token}'")
  print(f"Splits after merge: {current_splits}")


  # `vocab` is created in an earlier cell and persists in the kernel. Guard the
  # append so re-running this cell does not add the same tokens a second time.
  if new_token not in vocab:
    vocab.append(new_token)
  print(f'Updated Vocabulary: {vocab}')


  merges[best_pair] = new_token
  print(f"Updated Merges: {merges}")

  print("-" * 30)


--- Starting BPE Merges ---
Initial Splits: {('T', 'h', 'i', 's', '</w>'): 2, ('i', 's', '</w>'): 3, ('t', 'h', 'e', '</w>'): 4, ('f', 'i', 'r', 's', 't', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '.', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '</w>'): 1, ('s', 'e', 'c', 'o', 'n', 'd', '</w>'): 1, ('A', 'n', 'd', '</w>'): 1, ('t', 'h', 'i', 's', '</w>'): 2, ('t', 'h', 'i', 'r', 'd', '</w>'): 1, ('o', 'n', 'e', '.', '</w>'): 1, ('I', 's', '</w>'): 1, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '?', '</w>'): 1}
------------------------------

Merge Iteration 1/15
Top 5 Pair Frequencies: [(('s', '</w>'), 8), (('i', 's'), 7), (('t', 'h'), 7), (('h', 'i'), 5), (('h', 'e'), 4)]
Found Best Pair: ('s', '</w>') with Frequency: 8
Merging ('s', '</w>') into 's</w>'
Splits after merge: {('T', 'h', 'i', 's</w>'): 2, ('i', 's</w>'): 3, ('t', 'h', 'e', '</w>'): 4, ('f', 'i', 'r', 's', 't', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '.', '</w>'): 2, ('d', 'o', 'c', 'u', '

In [20]:
# Deduplicated, sorted view of the vocabulary so the listing is stable.
final_vocab_sorted = sorted(set(vocab))

print("\n--- BPE Merges Complete ---")
print(f"Final Vocabulary Size: {len(vocab)}")

# One learned merge rule per line, in the order the rules were learned.
print("\nLearned Merges (Pair -> New Token):")
for pair in merges:
    token = merges[pair]
    print(f"{pair} -> '{token}'")

print("\nFinal Word Splits after all merges:")
print(current_splits)

print("\nFinal Vocabulary (sorted):")
print(final_vocab_sorted)



--- BPE Merges Complete ---
Final Vocabulary Size: 35

Learned Merges (Pair -> New Token):
('s', '</w>') -> 's</w>'
('i', 's</w>') -> 'is</w>'
('t', 'h') -> 'th'
('th', 'e') -> 'the'
('the', '</w>') -> 'the</w>'
('d', 'o') -> 'do'
('do', 'c') -> 'doc'
('doc', 'u') -> 'docu'
('docu', 'm') -> 'docum'
('docum', 'e') -> 'docume'
('docume', 'n') -> 'documen'
('documen', 't') -> 'document'
('i', 'r') -> 'ir'
('.', '</w>') -> '.</w>'
('d', '</w>') -> 'd</w>'

Final Word Splits after all merges:
{('T', 'h', 'is</w>'): 2, ('is</w>',): 3, ('the</w>',): 4, ('f', 'ir', 's', 't', '</w>'): 2, ('document', '.</w>'): 2, ('document', '</w>'): 1, ('s', 'e', 'c', 'o', 'n', 'd</w>'): 1, ('A', 'n', 'd</w>'): 1, ('th', 'is</w>'): 2, ('th', 'ir', 'd</w>'): 1, ('o', 'n', 'e', '.</w>'): 1, ('I', 's</w>'): 1, ('document', '?', '</w>'): 1}

Final Vocabulary (sorted):
[' ', '.', '.</w>', '</w>', '?', 'A', 'I', 'T', 'c', 'd', 'd</w>', 'do', 'doc', 'docu', 'docum', 'docume', 'documen', 'document', 'e', 'f', 'h', '

## Assigning every token an ID

In [21]:
# Give each token its position in the sorted vocabulary as its integer ID.
token2id = dict(zip(final_vocab_sorted, range(len(final_vocab_sorted))))
print(token2id)

{' ': 0, '.': 1, '.</w>': 2, '</w>': 3, '?': 4, 'A': 5, 'I': 6, 'T': 7, 'c': 8, 'd': 9, 'd</w>': 10, 'do': 11, 'doc': 12, 'docu': 13, 'docum': 14, 'docume': 15, 'documen': 16, 'document': 17, 'e': 18, 'f': 19, 'h': 20, 'i': 21, 'ir': 22, 'is</w>': 23, 'm': 24, 'n': 25, 'o': 26, 'r': 27, 's': 28, 's</w>': 29, 't': 30, 'th': 31, 'the': 32, 'the</w>': 33, 'u': 34}


## Helper function to apply the learned BPE merges to a single word

In [22]:
def apply_bpe(word, merges, end_of_word="</w>"):
    """Split a single word into BPE tokens using the learned merge rules.

    Merges are applied in the order they were learned (earliest first), which
    reproduces the segmentation the training loop would have produced. Picking
    the pair by its position in the word instead can give a different result
    when two learned pairs overlap.

    Args:
        word: the word to tokenize (without the end-of-word marker).
        merges: learned merge rules keyed by ``(left, right)`` symbol pair.
            Iteration order must be the learning order, i.e. an insertion-ordered
            dict (as built by the training loop) or a list of pairs.
        end_of_word: marker appended to the word before merging.

    Returns:
        List of token strings; the last token ends with ``end_of_word``.
    """
    symbols = list(word) + [end_of_word]

    # Lower rank = learned earlier = higher priority.
    ranks = {pair: rank for rank, pair in enumerate(merges)}

    while len(symbols) > 1:
        candidates = {pair for pair in zip(symbols, symbols[1:]) if pair in ranks}
        if not candidates:
            break

        best_pair = min(candidates, key=ranks.__getitem__)
        merged_token = best_pair[0] + best_pair[1]

        # Rewrite the word, fusing every occurrence of the chosen pair.
        new_symbols = []
        i = 0
        while i < len(symbols):
            if i < len(symbols) - 1 and (symbols[i], symbols[i+1]) == best_pair:
                new_symbols.append(merged_token)
                i += 2
            else:
                new_symbols.append(symbols[i])
                i += 1
        symbols = new_symbols
    return symbols


## Tokenizer to encode text to tokens

In [23]:
def gpt_tokenizer(corpus, merges, vocab_map=None, unk_id=None):
  """Encode each document of a corpus as a list of token IDs.

  Every document is split on whitespace, each word is tokenized with
  ``apply_bpe``, and the resulting tokens are looked up in the vocabulary.

  Args:
    corpus: iterable of document strings.
    merges: learned merge rules, passed through to ``apply_bpe``.
    vocab_map: token -> ID mapping. Defaults to the notebook-level ``token2id``
      when omitted.
    unk_id: ID to emit for tokens missing from the mapping. With the default
      ``None`` such tokens are skipped, so decoding the IDs will not reproduce
      text that contains unseen characters.

  Returns:
    A list with one list of integer token IDs per document.
  """
  mapping = token2id if vocab_map is None else vocab_map
  encoded_corpus_ids = []

  for doc in corpus:
      token_ids = []
      # split() with no argument already ignores leading/trailing whitespace.
      for word in doc.split():
          for token in apply_bpe(word, merges):
              token_id = mapping.get(token, unk_id)
              if token_id is not None:
                  token_ids.append(token_id)
      encoded_corpus_ids.append(token_ids)
  return encoded_corpus_ids

## Decoder to get the actual text

In [24]:
# Reverse lookup table (ID -> token) used when decoding.
id2token = dict(zip(token2id.values(), token2id.keys()))

def decode_token_ids(token_ids, id_to_token=None, end_of_word="</w>"):
    """Turn a sequence of token IDs back into text.

    Args:
        token_ids: iterable of integer token IDs.
        id_to_token: ID -> token mapping. Defaults to the notebook-level
            ``id2token`` when omitted.
        end_of_word: marker that ends a word inside the tokens; each occurrence
            becomes a single space in the output.

    Returns:
        The decoded string with surrounding whitespace stripped.

    Raises:
        KeyError: if an ID is not present in the mapping.
    """
    mapping = id2token if id_to_token is None else id_to_token
    tokens = [mapping[token_id] for token_id in token_ids]
    # Tokens are concatenated directly; word boundaries come only from the marker.
    text = ''.join(tokens).replace(end_of_word, ' ').strip()
    return text
