In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import copy

import torch
import torch.nn.functional as F
from torch import nn
from tqdm.auto import tqdm
import random
from collections import Counter
import re

In [2]:
!wget "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

--2025-05-14 18:05:11--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-05-14 18:05:11 (12.5 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [3]:
with open('/content/input.txt', mode='r', encoding='utf-8') as f:
    text = f.read()

In [4]:
print('Length of data: ', len(text))
print(text[:100])

Length of data:  1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


## **Tokenizers**

Here are different **SUBWORD** tokenizers:

<br>

Tokenizer Type | Library/Example | Strength | Weakness | Used in
---------------|-----------------|----------|----------|---------
BPE | HuggingFace, GPT-2 | Fast, simple | Unstructured merges | GPT-2, RoBERTa |
Unigram LM | SentencePiece | Probabilistic, flexible | More complex to train | T5, mT5 |
WordPiece | TensorFlow/BERT | Good with rare words | Slower training than BPE | BERT, ALBERT |
SentencePiece Framework | SentencePiece | Handles whitespace + full text | Adds special chars | T5, XLNet |
Character-level | Custom / simple RNNs | No OOV, trivial implementation | Long sequences | Old RNNs, toy tasks |
Byte-level BPE | GPT-2 tokenizer | Handles any input data | Harder to read tokens | GPT-2, GPT-3

<br>

In previous notebook we implemented **BPE** tokenizer, since we wanted it for GPT models. Now we are going to implement **WordPiece** tokenizer for BERT implementation.

<br>

1.  **Byte Pair Encoding (BPE)**

Start with characters → iteratively merge most frequent adjacent pairs into new tokens.

- **Example:**
```python
"low lower lowest"
# Split to character level
→ ['l', 'o', 'w', 'l', 'o', 'w','e', 'r' ,'l', 'o', 'w', 'e', 's', 't']
# → Merge ('l', 'o') → 'lo'
→ ['lo', 'w', 'lo', 'w','e', 'r' ,'lo', 'w', 'e', 's', 't']
# Merge ('lo', 'w') → 'low'
→ ['low', 'low','e', 'r' ,'low', 'e', 's', 't']
→ ...
```

- **Pros:**
	- Efficient
    - Deterministic
    - Easy to implement (you already did!)
    - Good balance between vocab size and sequence length

- **Cons:**
    - Doesn’t model word boundaries explicitly
    - Can create strange merges (e.g., merging across morpheme boundaries)

> GPT-2, RoBERTa, CLIP. <br> When you want fast, trainable tokenization with high performance

<br>

2. **WordPiece:**

Similar to BPE but slightly different merge criterion:
WordPiece merges tokens to maximize the likelihood of a language model, not just frequency.

- **Example:**

"playing" might become ["play", "##ing"]
(## indicates subword continuation)

- **Pros:**
	- Handles rare words and OOV (out-of-vocabulary) tokens well
	- Popular and effective

- **Cons:**
	- Slower to train than BPE
	- May break words in unnatural places (like BPE)

> BERT, ALBERT, DistilBERT. <br> Best when used with bidirectional Transformer models

<br>

- GPT recommendations:

<br>

Use case | Recommended Tokenizer
---------|-----------------------
Large transformer decoder | BPE or Byte-level BPE |
Encoder-decoder models | Unigram (SentencePiece) |
Multilingual input | SentencePiece |
Highest compression | WordPiece or Unigram |
Fastest to implement | BPE

<br>

> Here We are going to tokenize *tiny Shakespeare*, which is poetic.

## **BPE Tokenizer**

In [5]:
def merge(words, pair):
    new_words = []
    for word in words:
        pairs = [[word[i] + ' ' + word[i+1]] for i in range(len(word) - 1)]
        if pair in [''.join(pair) for pair in pairs]:
            item = ' '.join(word)
            item = item.replace(' '.join(pair.split()), ''.join(pair.split()))
            new_words.append(item.split())
        else:
            new_words.append(word)

    return new_words

def PBE(text, num_iteration=1000, freq_limit=5, pattern=r"\s*\S+"):
    # Tokenize input text into list of character lists per word
    words = [[c for c in item] for item in re.findall(pattern, text)]

    pbar = tqdm(range(num_iteration))  # Show progress bar

    for _ in pbar:
        pairs = []
        for item in words:
            # Collect all adjacent character pairs from words
            pairs += [item[i] + ' ' + item[i+1] for i in range(len(item) - 1)]

        print(pairs[:10])
        # Find most common pair
        max = Counter(pairs).most_common(1)[0]

        if max[1] > freq_limit:
            pair = max[0]
            words = merge(words, pair)  # Merge the pair in all words
        else:
            break  # Stop if no pair is frequent enough

        pbar.set_description(f"Max frequency: {max[1]}, ")

    # Flatten the list of words into a single list of subwords/tokens
    return [w for word in words for w in word]

In [6]:
tokens = PBE(text, num_iteration=1)

  0%|          | 0/1 [00:00<?, ?it/s]

['F i', 'i r', 'r s', 's t', '  C', 'C i', 'i t', 't i', 'i z', 'z e']


## **WordPiece**

In [7]:
print(text[:500])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor


In [8]:
#  ^ → start of line
#  .*?:.* → any content that includes : (non-greedy before the :) --> (*?: zero or more characters, non-greedily)
#  $ → end of line
#  \n? → optionally match the newline
#  flags=re.MULTILINE → makes ^ and $ apply to each line

pattern = r'^.*?:.*$\n?'  # matches any line that contains a colon
cleaned = re.sub(pattern, '', text, flags=re.MULTILINE)
print(cleaned[:500])

Before we proceed any further, hear me speak.

Speak, speak.

You are all resolved rather to die than to famish?

Resolved. resolved.

First, you know Caius Marcius is chief enemy to the people.

We know't, we know't.

Let us kill him, and we'll have corn at our own price.
Is't a verdict?


One word, good citizens.

We are accounted poor citizens, the patricians good.
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
afflicts us, the object of


In [9]:
pattern = r'[.?!]\s*'  # Split on ., ?, ! followed by any space
sentences = re.split(pattern, cleaned)
words = [s.split() for s in sentences if len(s.split()) > 3]
words[1]

['You',
 'are',
 'all',
 'resolved',
 'rather',
 'to',
 'die',
 'than',
 'to',
 'famish']

In [10]:
pattern = r'^.*?:.*$\n?'
cleaned = re.sub(pattern, '', text, flags=re.MULTILINE)

# Split sentences
pattern = r'[.?!]\s*'
sentences = re.split(pattern, cleaned)

# Only keep sentences with at least 4 words
words = [s.split() for s in sentences if len(s.split()) > 3]

# Flatten the list of word lists
flatten_words = [word for sentence in words for word in sentence]

# Initialize vocabulary with characters
chars = list(set(''.join(flatten_words)))
chars.extend(['[UNK]', '[CLS]', '[SEP]'])

# Convert each word to a list of its characters
character_lvl_words = [[ch for ch in word] for word in flatten_words]

In [11]:
# Merge adjacent character pairs in tokenized words
def merge(words, pair_str):
    a, b = pair_str.split()  # Split the pair string into two parts
    new_words = []

    for word in words:
        i = 0
        new_word = []
        while i < len(word):
            # Merge the pair (a, b) if found
            if i < len(word) - 1 and word[i] == a and word[i + 1] == b:
                merged_token = a + b
                new_word.append(merged_token)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_words.append(new_word)

    return new_words

# Add WordPiece-style prefixes (##) to non-initial subword tokens
def add_wordpiece_prefixes(token_list):
    result = []
    for word in token_list:
        if not word:
            continue
        result.append([word[0]] + ['##' + t for t in word[1:]])
    return result

### **Likelihood**

1. **Use entropy(counts):**

    If you want to reduce uncertainty and encourage compact token distributions — closer to BERT’s goal in WordPiece.
2. **Use sum(counts * log(probs)):**
    
    If you want a likelihood-based heuristic that reflects how much total “confidence” the model has in the token set.

3. **frequency-based heuristic**

    When we want a fast, lightweight, and non-probabilistic training process.

$$
\text{score}(a, b) = \frac{\text{freq}(a \, b)}{\text{freq}(a) \times \text{freq}(b)}
$$

> freq(a b) = frequency of the pair (a, b) appearing together (adjacent).<br>
freq(a) = how often token a appears in the corpus.<br>
freq(b) = how often token b appears in the corpus.

In [12]:
def log_likelihood_sum(before, after):
    def sum_counts(counts):
        # Normalize to get probabilities
        probs = counts / counts.sum()
        return np.sum(counts * np.log(probs + 1e-12))  # small epsilon for stability

    # Flatten and count tokens using Counter
    flat_before = [token for word in before for token in word]
    flat_after = [token for word in after for token in word]

    before_counts = np.array(list(Counter(flat_before).values()), dtype=np.float64)
    after_counts = np.array(list(Counter(flat_after).values()), dtype=np.float64)

    return sum_counts(after_counts) - sum_counts(before_counts)

def log_likelihood_entropy(before, after):
    def entropy(counts):
        # Normalize to get probabilities
        probs = counts / counts.sum()
        return -np.sum(probs * np.log(probs + 1e-12))  # small epsilon for stability

    # Flatten and count tokens using Counter
    flat_before = [token for word in before for token in word]
    flat_after = [token for word in after for token in word]

    before_counts = np.array(list(Counter(flat_before).values()), dtype=np.float64)
    after_counts = np.array(list(Counter(flat_after).values()), dtype=np.float64)

    return entropy(before_counts) - entropy(after_counts)  # Δ entropy: positive means after-merge distribution is more compact (lower entropy)

In [13]:
def clean_text(text):
    # Remove metadata lines (e.g., "speaker: text")
    pattern = r'^.*?:.*$\n?'
    cleaned = re.sub(pattern, '', text, flags=re.MULTILINE)

    # Split sentences
    pattern = r'[.?!]\s*'
    sentences = re.split(pattern, cleaned)

    # Only keep sentences with at least 4 words
    words = [s.split() for s in sentences if len(s.split()) > 3]
    sentences = [s for s in sentences if len(s.split()) > 3]

    return words, sentences

In [14]:
def WordPiece(text, num_iteration=1000, freq_limit=5, likeli_iteration=30, load_char_lvl=False):

    if load_char_lvl:
        character_lvl_words = np.load(load_char_lvl, allow_pickle=True).tolist()

    else:
        words, _ = clean_text(text)

        # Flatten the list of word lists
        flatten_words = [word for sentence in words for word in sentence]

        # Convert each word to a list of its characters
        character_lvl_words = [[ch for ch in word] for word in flatten_words]

    # Progress bar for the outer loop
    pbar = tqdm(range(num_iteration), desc="Outer loop progress")

    for _ in pbar:
        # Count character pair frequencies
        pairs = {}
        for word in character_lvl_words:
            for pair in zip(word[:-1], word[1:]):
                pairs[pair] = pairs.get(pair, 0) + 1

        best_prob = float('-inf')
        best_pair = None
        best_merged = None

        # Get the most common pairs and limit iterations (e.g., 100 or fewer)
        common_pairs = Counter(pairs).most_common(likeli_iteration)

        # Create the inner progress bar only once, outside the outer loop
        inner_pbar = tqdm(common_pairs, desc="Evaluating pairs", leave=False)

        for pair, freq in inner_pbar:
            if freq > freq_limit:
                pair_str = ' '.join(pair)
                merged = merge(character_lvl_words, pair_str)
                prob_dif = log_likelihood_entropy(character_lvl_words, merged) # We can change the function to `log_likelihood_sum`
                if prob_dif > best_prob:
                    best_prob = prob_dif
                    best_pair = pair
                    best_merged = merged

            # Update the inner progress bar description with current progress
            inner_pbar.set_postfix({"Best pair": best_pair, "Δ log-likelihood": f"{best_prob:.4f}"})

        if best_pair is None:
            break  # Stop if no good merges are found

        character_lvl_words = best_merged
        np.save('/content/WordPiececharacter_lvl_words.npy', np.array(character_lvl_words, dtype=object))
        # !cp /content/WordPiececharacter_lvl_words.npy /content/drive/MyDrive/LLM/
        pbar.set_description(f"Best pair: {best_pair}, Δ log-likelihood: {best_prob:.4f}")

    # Apply WordPiece prefixes and return unique tokens
    tokens = add_wordpiece_prefixes(character_lvl_words)
    flat_tokens = [tok for word in tokens for tok in word]
    # Add special characters to tokens
    flat_tokens += ['[UNK]', '[CLS]', '[SEP]', '.'] # '[CLS]', '[SEP]' are not necessary, since I want train MLM (Maked language model)
    return sorted(set(flat_tokens))

In [15]:
tokens = WordPiece(
    text, num_iteration=1000-999,
    likeli_iteration=30,
    load_char_lvl='/content/drive/MyDrive/LLM/WordPiececharacter_lvl_words.npy'
)

Outer loop progress:   0%|          | 0/1 [00:00<?, ?it/s]

Evaluating pairs:   0%|          | 0/30 [00:00<?, ?it/s]

In [16]:
len(tokens)

1757

In [17]:
def wordpiece_tokenize(word, tokens_list, unk_token='[UNK]'):
    # List to store the resulting tokens
    tokeninzed = []
    start = 0

    # Continue until the entire word is processed
    while start < len(word):
        end = len(word)
        matched = None

        # Try to find the longest substring starting at 'start' that exists in the vocabulary
        while start < end:
            substr = word[start:end]

            # For non-initial tokens, add WordPiece continuation prefix
            if start > 0:
                substr = "##" + substr

            # Check if this substring is in the vocabulary
            if substr in tokens_list:
                matched = substr
                break  # Found the longest match

            # Shorten the substring by one character from the right
            end -= 1

        # If no match found, treat the entire word (or remaining part) as unknown
        if matched is None:
            tokeninzed.append(unk_token)
            break

        # Add matched token to result
        tokeninzed.append(matched)

        # Update the start position:
        # If it's a continuation token (starts with ##), move forward by the number of characters matched (excluding ##)
        # If it's the initial token, move by the full length of the token
        start = end if matched.startswith("##") else len(matched)

    return tokeninzed

In [18]:
# Check the tokenizer
for item in ['resolve', 'Resolve', 'speak', 'relieved', 'object', 'objection']:
    tokenized = wordpiece_tokenize(item, tokens)
    print(f"{item:.^12} tokens are: {tokenized}")

..resolve... tokens are: ['r', '##e', '##so', '##l', '##ve']
..Resolve... tokens are: ['Re', '##so', '##l', '##ve']
...speak.... tokens are: ['speak']
..relieved.. tokens are: ['r', '##el', '##ie', '##ve', '##d']
...object... tokens are: ['o', '##b', '##ject']
.objection.. tokens are: ['o', '##b', '##ject', '##ion']


In [19]:
# Make sure that '[PAD]' has 0 index
tokens.remove('[UNK]')
for special in ['[UNK]', '[MASK]', '[PAD]']:
        tokens.insert(0, special)

In [21]:
tokens[:5]

['[PAD]', '[MASK]', '[UNK]', '##$', "##'"]

In [22]:
stoi = {ch:i for i, ch in enumerate(tokens)} # String to int
itos = {i:ch for i, ch in enumerate(tokens)} # Int to string

decode_word_piece = lambda string, tokens: [stoi[item] if item in stoi else stoi['[UNK]']
                                            for item in wordpiece_tokenize(string, tokens)]
encode_word_piece = lambda ids: ''.join([tok if not tok.startswith('##') else tok[2:]
                                         for tok in [itos[item] for item in ids]])

In [23]:
# Let's test the decoder-encoder
ids = decode_word_piece('hello!', tokens)
print(ids, '-->', [itos[id] for id in ids])

encode_word_piece(ids)

[1317, 466, 2] --> ['hel', '##lo', '[UNK]']


'hello[UNK]'

In [24]:
_, cleaned_text = clean_text(text)
print(f'Length of data: {len(cleaned_text):^20}')
print(f'Maxmum number of words in a sentence: {len(max(cleaned_text, key=len).split()):^20}')
print(f'Minimum number of words in a sentence: {len(min(cleaned_text, key=len).split()):^20}')

Length of data:        10312        
Maxmum number of words in a sentence:         192         
Minimum number of words in a sentence:          4          


In [25]:
def split_poetic_text(text, max_words=40, min_words=10):
    # Split text by strong punctuation first
    clauses = re.split(r'(?<=[,;])\s+', text)

    segments = []
    current = []

    for clause in clauses:
        current.append(clause.strip())
        word_count = sum(len(part.split()) for part in current)

        if word_count >= max_words or (
            word_count >= min_words and clause.strip().endswith((',', ';'))
        ):
            segments.append(' '.join(current).strip())
            current = []

    # Add remaining
    if current:
        segments.append(' '.join(current).strip())

    return segments

In [26]:
new_text = []
for item in cleaned_text:
    item = item.replace('\n', ' ')
    if len(item.split()) > 50:
        new_text.append(split_poetic_text(item)[0])
    else:
        new_text.append(item)

In [27]:
max_len = 90
i = 0
while True:
    if len(new_text[i].split()) > max_len:
        new_text[i] = new_text[i][:max_len]
        i += 1
        continue

    merged  = new_text[i] + ' . ' + new_text[i+1]
    if len(merged.split()) > max_len:
        i += 1
        continue

    new_text[i] = merged
    del new_text[i+1]

    if i >= len(new_text)-2:
        break

new_text = new_text[:-1]

In [28]:
print(f'Length of data: {len(new_text):^20}')
print(f'Maxmum number of words in a sentence: {len(max(new_text, key=len).split()):^20}')
print(f'Minimum number of words in a sentence: {len(min(new_text, key=len).split()):^20}')

Length of data:         1933        
Maxmum number of words in a sentence:          87         
Minimum number of words in a sentence:          40         


In [29]:
tokenized_text = []
for i, item in (enumerate(new_text)):
    item = [wordpiece_tokenize(i, tokens) for i in item.split()]
    tokenized_text.append([i for i in item for i in i])

In [30]:
print(f'Length of data: {len(tokenized_text):^20}')
print(f'Maxmum number of tokens in a sentence: {len(max(tokenized_text, key=len)):^20}')
print(f'Minimum number of tokens in a sentence: {len(min(tokenized_text, key=len)):^20}')

Length of data:         1933        
Maxmum number of tokens in a sentence:         218         
Minimum number of tokens in a sentence:          73         


In [31]:
tokenized_text[0][:10]

['Be', '##fore', 'we', 'pro', '##ce', '##ed', 'any', 'fur', '##t', '##her,']