<a href="https://colab.research.google.com/github/anjalirj27/Llama4/blob/main/tokenizer_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/anjalirj27/Llama4.git
%cd Llama4

Cloning into 'Llama4'...
/content/Llama4/Llama4/Llama4/Llama4/Llama4


In [None]:
corpus = [
    "This is the first documnet.",
    "This document is the second document.",
    "And this is the third one.",
    "Is this the first document?",
]

In [None]:
print("Training Corpus:")
for doc in corpus:
  print(doc)

Training Corpus:
This is the first documnet.
This document is the second document.
And this is the third one.
Is this the first document?


In [None]:
unique_chars = set()
for doc in corpus:
  for char in doc:
    unique_chars.add(char)

In [None]:
vocab = list(unique_chars)

In [None]:
print(vocab)

['.', 'o', 'h', 't', 'I', '?', 'e', 'i', 'u', 'f', 'n', ' ', 'A', 's', 'c', 'r', 'T', 'm', 'd']


In [None]:
vocab.sort()
print(vocab)

[' ', '.', '?', 'A', 'I', 'T', 'c', 'd', 'e', 'f', 'h', 'i', 'm', 'n', 'o', 'r', 's', 't', 'u']


In [None]:
end_of_word = "</w"
vocab.append(end_of_word)

In [None]:
print("Initial Vocabulary:")
print(vocab)
print(f"Vocabulary Size: {len(vocab)}")

Initial Vocabulary:
[' ', '.', '?', 'A', 'I', 'T', 'c', 'd', 'e', 'f', 'h', 'i', 'm', 'n', 'o', 'r', 's', 't', 'u', '</w']
Vocabulary Size: 20


In [None]:
word_splits = {}
for doc in corpus:
  words = doc.split(' ')
  for word in words:
    if word:
      char_list = list(word) + [end_of_word]
      word_tuple = tuple(char_list)
      if word_tuple not in word_splits:
        word_splits[word_tuple] = 0
      word_splits[word_tuple] += 1

print("\nPre-tokenized Word Frequencies")
print(word_splits)


Pre-tokenized Word Frequencies
{('T', 'h', 'i', 's', '</w'): 2, ('i', 's', '</w'): 3, ('t', 'h', 'e', '</w'): 4, ('f', 'i', 'r', 's', 't', '</w'): 2, ('d', 'o', 'c', 'u', 'm', 'n', 'e', 't', '.', '</w'): 1, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '</w'): 1, ('s', 'e', 'c', 'o', 'n', 'd', '</w'): 1, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '.', '</w'): 1, ('A', 'n', 'd', '</w'): 1, ('t', 'h', 'i', 's', '</w'): 2, ('t', 'h', 'i', 'r', 'd', '</w'): 1, ('o', 'n', 'e', '.', '</w'): 1, ('I', 's', '</w'): 1, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '?', '</w'): 1}


In [None]:
word_tuple

('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '?', '</w')

In [None]:
import collections

In [None]:
def get_pair_stats(splits):
  """Counts the frequency of adjacent pairs in the word splits."""
  pair_counts = collections.defaultdict(int)
  for word_tuple, freq in splits.items():
    symbols = list(word_tuple)
    for i in range(len(symbols) - 1):
      pair = (symbols[i], symbols[i+1])
      pair_counts[pair] += freq
    return pair_counts

In [None]:
def merge_pair(pair_to_merge, splits):
  """Merges the specified pair in the word splits """
  new_splits = {}
  (first, second) = pair_to_merge
  merged_token = first + second
  for word_tuple, freq in splits.items():
    symbols = list(word_tuple)
    new_symbols = []
    i = 0
    while i< len(symbols):
      #if the current and next symbol match and the pair to merge
      if i < len(symbols) - 1 and symbols[i] == first and symbols[i+1] == second:
        new_symbols.append(merged_token)
        i+=2 #skip the next symbol
      else:
        new_symbols.append(symbols[i])
        i+= 1
      new_splits[tuple(new_symbols)] = freq #use updated symbol list as the key
    return new_splits

In [None]:
num_merges = 15

In [None]:
merges = {}

In [None]:
current_splits = word_splits.copy()

In [None]:
print("\n------------> Starting BPE Merges --------->")
print(f"Initial Word Splits: {current_splits}")
print("-"*35)


------------> Starting BPE Merges --------->
Initial Word Splits: {('T', 'h', 'i', 's', '</w'): 2, ('i', 's', '</w'): 3, ('t', 'h', 'e', '</w'): 4, ('f', 'i', 'r', 's', 't', '</w'): 2, ('d', 'o', 'c', 'u', 'm', 'n', 'e', 't', '.', '</w'): 1, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '</w'): 1, ('s', 'e', 'c', 'o', 'n', 'd', '</w'): 1, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '.', '</w'): 1, ('A', 'n', 'd', '</w'): 1, ('t', 'h', 'i', 's', '</w'): 2, ('t', 'h', 'i', 'r', 'd', '</w'): 1, ('o', 'n', 'e', '.', '</w'): 1, ('I', 's', '</w'): 1, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '?', '</w'): 1}
-----------------------------------


In [None]:
for i in range(num_merges):
  print(f"\nMerge Iteration {i+1}/{num_merges}")

  #calculate pair frequencies ---> Step 1.
  pair_stats = get_pair_stats(current_splits)
  if not pair_stats:
    break
  #Print top 7 pairs for inspection (Optional)
  sorted_pairs = sorted(pair_stats.items(), key = lambda item: item[1], reverse = True)
  print(f"Top 7 pairs: {sorted_pairs[:7]}")


Merge Iteration 1/15
Top 7 pairs: [(('T', 'h'), 2), (('h', 'i'), 2), (('i', 's'), 2), (('s', '</w'), 2)]

Merge Iteration 2/15
Top 7 pairs: [(('T', 'h'), 2), (('h', 'i'), 2), (('i', 's'), 2), (('s', '</w'), 2)]

Merge Iteration 3/15
Top 7 pairs: [(('T', 'h'), 2), (('h', 'i'), 2), (('i', 's'), 2), (('s', '</w'), 2)]

Merge Iteration 4/15
Top 7 pairs: [(('T', 'h'), 2), (('h', 'i'), 2), (('i', 's'), 2), (('s', '</w'), 2)]

Merge Iteration 5/15
Top 7 pairs: [(('T', 'h'), 2), (('h', 'i'), 2), (('i', 's'), 2), (('s', '</w'), 2)]

Merge Iteration 6/15
Top 7 pairs: [(('T', 'h'), 2), (('h', 'i'), 2), (('i', 's'), 2), (('s', '</w'), 2)]

Merge Iteration 7/15
Top 7 pairs: [(('T', 'h'), 2), (('h', 'i'), 2), (('i', 's'), 2), (('s', '</w'), 2)]

Merge Iteration 8/15
Top 7 pairs: [(('T', 'h'), 2), (('h', 'i'), 2), (('i', 's'), 2), (('s', '</w'), 2)]

Merge Iteration 9/15
Top 7 pairs: [(('T', 'h'), 2), (('h', 'i'), 2), (('i', 's'), 2), (('s', '</w'), 2)]

Merge Iteration 10/15
Top 7 pairs: [(('T', 'h

In [None]:
#Find best pair ---> Step 2.
best_pair = max(pair_stats, key = pair_stats.get)
best_freq = pair_stats[best_pair]
print(f"Found Best Pair: {best_pair} with frequency: {best_freq}")

Found Best Pair: ('T', 'h') with frequency: 2


In [None]:
current_splits

{('T', 'h', 'i', 's', '</w'): 2,
 ('i', 's', '</w'): 3,
 ('t', 'h', 'e', '</w'): 4,
 ('f', 'i', 'r', 's', 't', '</w'): 2,
 ('d', 'o', 'c', 'u', 'm', 'n', 'e', 't', '.', '</w'): 1,
 ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '</w'): 1,
 ('s', 'e', 'c', 'o', 'n', 'd', '</w'): 1,
 ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '.', '</w'): 1,
 ('A', 'n', 'd', '</w'): 1,
 ('t', 'h', 'i', 's', '</w'): 2,
 ('t', 'h', 'i', 'r', 'd', '</w'): 1,
 ('o', 'n', 'e', '.', '</w'): 1,
 ('I', 's', '</w'): 1,
 ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '?', '</w'): 1}

In [None]:
#Merge the Best Pair -->Step 3.
current_splits = merge_pair(best_pair, current_splits)
new_token = best_pair[0] + best_pair[1]
print(f"Merging {best_pair} into '{new_token}'")
print(f"Splits after merge: {current_splits}")

Merging ('T', 'h') into 'Th'
Splits after merge: {('Th',): 2, ('Th', 'i'): 2, ('Th', 'i', 's'): 2, ('Th', 'i', 's', '</w'): 2}


In [None]:
#Vocab Updation ---->Step 3.
vocab.append(new_token)
print(f"Updated Vocabulary: {vocab}")

Updated Vocabulary: [' ', '.', '?', 'A', 'I', 'T', 'c', 'd', 'e', 'f', 'h', 'i', 'm', 'n', 'o', 'r', 's', 't', 'u', '</w', 'Th']


In [None]:
#Store Merge Rule ---->Step 5
merges[best_pair] = new_token
print(f"Updated Merges: {merges}")
print('-'*50)

Updated Merges: {('T', 'h'): 'Th'}
--------------------------------------------------


In [None]:
print("\n-------------- BPE Merges Completed ---------------------")
print(f"Final Vocabulary Size: {len(vocab)}")


-------------- BPE Merges Completed ---------------------
Final Vocabulary Size: 21


In [None]:
#Pretty print merges
for pair, token in merges.items():
  print(f"{pair} -> '{token}")

('T', 'h') -> 'Th


In [None]:
print("\nFinal Word Splits after all merges:")
print(current_splits)


Final Word Splits after all merges:
{('Th',): 2, ('Th', 'i'): 2, ('Th', 'i', 's'): 2, ('Th', 'i', 's', '</w'): 2}


In [None]:
print("\nFinal vcabulary (Sorted):")
final_vocab_sorted = sorted(list(set(vocab)))
print(final_vocab_sorted)


Final vcabulary (Sorted):
[' ', '.', '</w', '?', 'A', 'I', 'T', 'Th', 'c', 'd', 'e', 'f', 'h', 'i', 'm', 'n', 'o', 'r', 's', 't', 'u']
