While editing this notebook, don't change cell types as that confuses the autograder.

Before you turn this notebook in, make sure everything runs as expected. First, **restart the kernel** (in the menubar, select Kernel $\rightarrow$ Restart) and then **run all cells** (in the menubar, select Cell $\rightarrow$ Run All).

Make sure you fill in any place that says `YOUR CODE HERE` or "YOUR ANSWER HERE", as well as your name below:

In [None]:
NAME = ""

_Understanding Deep Learning_

---

<a href="https://colab.research.google.com/github/DL4DS/sp2024_notebooks/blob/main/release/nbs12/12_3_Tokenization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Notebook 12.3: Tokenization

This notebook builds a set of tokens from a text string as in figure 12.8 of the book,
employing a simplified version of the byte pair encoding algorithm.

In [None]:
import re, collections

In [None]:
text = "a sailor went to sea sea sea "+\
                  "to see what he could see see see "+\
                  "but all that he could see see see "+\
                  "was the bottom of the deep blue sea sea sea"

In [None]:
print(text)
print(len(text))

## Helper Functions

To begin with, the tokens are the individual letters and the `</w>` whitespace token. To visualize this, we represent each word in terms of these tokens, with spaces between the tokens to delineate them.

The tokenized text is stored in a structure that represents each word as tokens together with the count of how often that word occurs.  We'll call this the *vocabulary*.
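
As a rough illustration of this format, the sketch below builds the same kind of structure for a short toy string. The names `toy_text` and `toy_vocab` are made up for this example, and `collections.Counter` is used here only for brevity; the routine in this notebook uses a `defaultdict` instead.

```python
import collections

# Toy input chosen only for illustration (not the notebook's text).
toy_text = "sea see sea"

# Each word becomes its characters separated by spaces, followed by the </w> token.
toy_vocab = collections.Counter(
    " ".join(word) + " </w>" for word in toy_text.split()
)

print(dict(toy_vocab))  # {'s e a </w>': 2, 's e e </w>': 1}
```

Each key is one word written as space-separated tokens, and each value is how often that word occurs.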

In [None]:
def initialize_vocabulary(text):
  """
  Initialize the vocabulary for the BPE algorithm.
   
  Initialize the vocabulary from the text string we will be using to 'train' 
  the algorithm. The vocabulary is a dictionary.  The keys are the words in the
  text and the values are the frequency of the word in the text.

  Returns:
  vocab: A dictionary where the keys are the words in the text and the values
          are the frequency of the word in the text.
  """
  vocab = collections.defaultdict(int)
  words = text.strip().split()
  for word in words:
      vocab[' '.join(list(word)) + ' </w>'] += 1
  return vocab

In [None]:
vocab = initialize_vocabulary(text)
print('Vocabulary: {}'.format(vocab))
print('Size of vocabulary: {}'.format(len(vocab)))

Find all the tokens in the current vocabulary, along with their frequencies.

In [None]:
def get_tokens_and_frequencies(vocab):
  """
  Get the tokens and their frequencies from the vocabulary.
  
  Returns:
  tokens: A dictionary where the keys are the tokens and the values are the
          frequency of the token in the vocabulary.
  """
  tokens = collections.defaultdict(int)
  for word, freq in vocab.items():
      word_tokens = word.split()
      for token in word_tokens:
          tokens[token] += freq
  return tokens

In [None]:
tokens = get_tokens_and_frequencies(vocab)
print('Tokens: {}'.format(tokens))
print('Number of tokens: {}'.format(len(tokens)))

Find each pair of adjacent tokens in the vocabulary
and count them.  We will subsequently merge the most frequently occurring pair.
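
To make the counting concrete, here is a minimal sketch on a hypothetical two-word vocabulary (`toy_vocab`, `toy_pairs`, and `best_pair` are illustrative names, not part of the notebook). Each pair is weighted by the frequency of the word it appears in.

```python
import collections

# Hypothetical toy vocabulary in the "space-separated tokens -> count" format.
toy_vocab = {"s e a </w>": 2, "s e e </w>": 1}

toy_pairs = collections.Counter()
for word, freq in toy_vocab.items():
    symbols = word.split()
    # zip pairs each symbol with its right-hand neighbour.
    for left, right in zip(symbols, symbols[1:]):
        toy_pairs[(left, right)] += freq

# Pick the pair with the highest count.
best_pair = max(toy_pairs, key=toy_pairs.get)
print(best_pair, toy_pairs[best_pair])  # ('s', 'e') 3
```

Note that when several pairs share the highest count, `max` returns the first one it encounters in the dictionary's iteration order.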

In [None]:
def get_pairs_and_counts(vocab):
    """
    Get the pairs of symbols and their counts from the vocabulary.
    
    Returns:
    pairs: A dictionary where the keys are the pairs of symbols and the values
            are the frequency of the pair in the vocabulary.
    """
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[symbols[i],symbols[i+1]] += freq
    return pairs

In [None]:
pairs = get_pairs_and_counts(vocab)
print(f'Pairs: {pairs}')
print(f'Number of distinct pairs: {len(pairs)}')

most_frequent_pair = max(pairs, key=pairs.get)
print(f'Most frequent pair: {most_frequent_pair}')

Merge every instance of the most frequent pair in the vocabulary.
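
The merge routine in this notebook wraps the pair in the lookarounds `(?<!\S)` and `(?!\S)`. The sketch below suggests why, using a hypothetical word whose first token `as` merely ends in `s`. The variable names and the example words are assumptions made for illustration.

```python
import re

pair = ("s", "e")
bigram = re.escape(" ".join(pair))
merged = "".join(pair)

# Hypothetical word whose first token "as" only ends in "s".
word = "as e a </w>"

# Naive substitution: matches "s e" even though "s" is just the tail of "as".
naive = re.sub(bigram, merged, word)

# Guarded substitution: the match must begin and end at whitespace or a string edge.
guarded_pattern = re.compile(r"(?<!\S)" + bigram + r"(?!\S)")
guarded = guarded_pattern.sub(merged, word)

print(naive)    # ase a </w>
print(guarded)  # as e a </w>

# A word that really contains the tokens "s" and "e" side by side is still merged.
print(guarded_pattern.sub(merged, "s e a </w>"))  # se a </w>
```

Without the lookarounds, the substitution would fuse parts of two different tokens and silently corrupt the vocabulary.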

In [None]:
def merge_pair_in_vocabulary(pair, vocab_in):
    """
    Merge a pair of symbols in the vocabulary.
    
    Returns:
    vocab_out: A dictionary where the keys are the words of the vocabulary
               with every occurrence of the pair merged into a single token,
               and the values are the frequencies of those words.
    """
    vocab_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in vocab_in:
        word_out = p.sub(''.join(pair), word)
        vocab_out[word_out] = vocab_in[word]
    return vocab_out

In [None]:
vocab = merge_pair_in_vocabulary(most_frequent_pair, vocab)
print(f'Vocabulary: {vocab}')
print(f'Size of vocabulary: {len(vocab)}')

Recompute the tokens, which now include the newly merged token `se`.

In [None]:
tokens = get_tokens_and_frequencies(vocab)
print(f'Tokens: {tokens}')
print(f'Number of tokens: {len(tokens)}')

## Byte Pair Encoding

Now let's write the full tokenization routine.

In [None]:
# TODO -- write this routine by filling in the missing parts,
# calling the above routines
def tokenize(text, num_merges):
  # Initialize the vocabulary from the input text
  # vocab = (your code here)
  vocab = initialize_vocabulary(text)

  for i in range(num_merges):
    # Find the tokens and how often they occur in the vocabulary
    # tokens = (your code here)
    # YOUR CODE HERE
    raise NotImplementedError()

    # Find the pairs of adjacent tokens and their counts
    # pairs = (your code here)
    # YOUR CODE HERE
    raise NotImplementedError()

    # Find the most frequent pair
    # most_frequent_pair = (your code here)
    # YOUR CODE HERE
    raise NotImplementedError()
    print(f'Iter {i}: Most frequent pair: {most_frequent_pair}')

    # Merge the most frequent pair in the vocabulary
    # vocab = (your code here)
    # YOUR CODE HERE
    raise NotImplementedError()

  # Find the tokens and how often they occur in the vocabulary one last time
  # tokens = (your code here)
  tokens = get_tokens_and_frequencies(vocab)

  return tokens, vocab

In [None]:
tokens, vocab = tokenize(text, num_merges=22)

In [None]:
print(f'Tokens: {tokens}')
print(f'Number of tokens: {len(tokens)}')
print(f'Vocabulary: {vocab}')
print(f'Size of vocabulary: {len(vocab)}')

After 22 merges, you should have 27 distinct tokens and a vocabulary of 18 distinct words.

In [None]:
# Test code. Do not edit.

assert len(vocab) == 18
assert len(tokens) == 27