In [1]:
!pip install -q torch torchvision matplotlib

In [2]:
import os
import urllib.request

DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
DATA_PATH = "shakespeare.txt"

# Download with the standard library instead of `!wget -q ... -O`. A failing
# shell command does not stop the notebook, and `-O` can leave an empty output
# file behind, so the tokenizer would silently train on an empty corpus.
# urlretrieve raises on HTTP/network errors instead. The file is fetched only
# when missing or empty, so re-running the notebook does not re-download it.
if not os.path.exists(DATA_PATH) or os.path.getsize(DATA_PATH) == 0:
    urllib.request.urlretrieve(DATA_URL, DATA_PATH)

with open(DATA_PATH, "r", encoding="utf-8") as f:
    text = f.read()

assert text, f"{DATA_PATH} is empty"
print(text[:500])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor


In [3]:
from collections import Counter
from typing import List, Tuple, Dict

def corpus_to_tokens(text: str) -> List[str]:
    """Split raw text into its initial character-level token sequence.

    Every character, including spaces and newlines, becomes its own token.
    This is the starting point that the BPE merges build on.
    """
    return [character for character in text]

def get_pair_stats(tokens: List[str]) -> Dict[Tuple[str, str], int]:
    """Count how often each pair of adjacent tokens occurs.

    Returns a Counter mapping ``(left, right)`` to its number of occurrences.
    Overlapping pairs are all counted, so ``['a', 'a', 'a']`` gives
    ``('a', 'a') -> 2``. Keys appear in order of first occurrence. A sequence
    with fewer than two tokens yields an empty Counter.
    """
    # Zipping the list against itself shifted by one visits every adjacent
    # pair, and produces nothing at all when there are fewer than two tokens.
    return Counter(zip(tokens, tokens[1:]))

def merge_tokens(tokens: List[str], pair: Tuple[str, str]) -> List[str]:
    """Replace every occurrence of ``pair`` in ``tokens`` with its concatenation.

    The scan runs greedily left-to-right and merges never overlap. For
    example, merging ``('a', 'a')`` in ``['a', 'a', 'a']`` gives
    ``['aa', 'a']``. A new list is returned and the input is left untouched.
    """
    first, second = pair
    merged = first + second
    result: List[str] = []
    # True while result[-1] is an unmerged token equal to `first`, i.e. the
    # next token may complete the pair.
    pending = False
    for tok in tokens:
        if pending and tok == second:
            result[-1] = merged
            # A freshly merged token must not be merged again.
            pending = False
        else:
            result.append(tok)
            pending = tok == first
    return result

def fit_bpe(
    text: str,
    num_merges: int = 300,
    verbose: bool = True
) -> Dict[str, object]:
    """Learn character-level byte-pair-encoding merge rules from ``text``.

    Training starts from one token per character. Each step counts adjacent
    token pairs, merges the most frequent pair everywhere in the corpus, and
    records that pair as a merge rule. When several pairs share the highest
    count, ``max`` keeps the first one in the stats' iteration order.

    Args:
        text: Training corpus.
        num_merges: Maximum number of merge steps. Training stops early once
            no adjacent pairs remain (fewer than two tokens).
        verbose: If True, print a message when training stops early.

    Returns:
        A dict with these keys:
            "tokens": the corpus as a token list after all merges,
            "vocab":  set of the base characters plus every merged token,
            "merges": list of (left, right) pairs in the order they were
                      learned. Apply them in that order to encode new text.
    """
    tokens = corpus_to_tokens(text)
    vocab = set(tokens)
    merges: List[Tuple[str, str]] = []

    for step in range(num_merges):
        stats = get_pair_stats(tokens)
        if not stats:
            if verbose:
                print(f"[Step {step}] no available pairs!")
            break

        # Most frequent adjacent pair; its count is not needed afterwards.
        best_pair, _ = max(stats.items(), key=lambda kv: kv[1])

        tokens = merge_tokens(tokens, best_pair)
        merges.append(best_pair)
        vocab.add(best_pair[0] + best_pair[1])

    return {
        "tokens": tokens,
        "vocab": vocab,
        "merges": merges,  # merge rules, in learned order
    }

In [4]:
NUM_MERGES = 300

# Train the tokenizer on the full corpus, then unpack the learned artifacts.
bpe_model = fit_bpe(text, num_merges=NUM_MERGES, verbose=True)
merges, vocab = bpe_model["merges"], bpe_model["vocab"]

print("Training done!")
print("total merge steps:", len(merges))
print("final vocab size:", len(vocab))

Training done!
total merge steps: 300
final vocab size: 365


In [5]:
def show_merges(merges, n_head=30, n_tail=10):
    """Print the first ``n_head`` and last ``n_tail`` merge rules, ranked from 1.

    Each rule is printed as ``rank: (left, right) -> merged``. The head and
    tail never overlap: for short lists, the tail covers only rules the head
    did not already print. A "..." line marks any skipped rules in between.

    Args:
        merges: sequence of (left, right) string pairs in learned order.
        n_head: number of leading rules to print (negative treated as 0).
        n_tail: number of trailing rules to print (negative treated as 0).
    """
    total = len(merges)
    head_end = min(max(n_head, 0), total)
    # Start the tail no earlier than the end of the head. This stops short
    # lists from printing rules twice, and keeps ranks >= 1 even when
    # n_tail > len(merges).
    tail_start = max(head_end, total - max(n_tail, 0))

    print("=== first few merge rules ===")
    for rank, (a, b) in enumerate(merges[:head_end], start=1):
        print(f"{rank:4d}: ({a!r}, {b!r}) -> {a + b!r}")

    if tail_start > head_end:
        print("...")
    if tail_start < total:
        print("\n=== last few merge rules ===")
        for rank, (a, b) in enumerate(merges[tail_start:], start=tail_start + 1):
            print(f"{rank:4d}: ({a!r}, {b!r}) -> {a + b!r}")

# Print the first 20 and the last 10 learned merge rules.
show_merges(merges, n_head=20, n_tail=10)

=== first few merge rules ===
   1: ('e', ' ') -> 'e '
   2: ('t', 'h') -> 'th'
   3: ('t', ' ') -> 't '
   4: ('s', ' ') -> 's '
   5: ('d', ' ') -> 'd '
   6: (',', ' ') -> ', '
   7: ('o', 'u') -> 'ou'
   8: ('e', 'r') -> 'er'
   9: ('i', 'n') -> 'in'
  10: ('y', ' ') -> 'y '
  11: ('a', 'n') -> 'an'
  12: (':', '\n') -> ':\n'
  13: ('o', 'r') -> 'or'
  14: ('o', ' ') -> 'o '
  15: ('e', 'n') -> 'en'
  16: ('\n', '\n') -> '\n\n'
  17: ('a', 'r') -> 'ar'
  18: (' ', 'th') -> ' th'
  19: ('o', 'n') -> 'on'
  20: ('l', 'l') -> 'll'
...

=== lasr few merge rules ===
 291: (' a', 'm') -> ' am'
 292: (' th', 'y ') -> ' thy '
 293: ('O', 'N') -> 'ON'
 294: (' h', 'e ') -> ' he '
 295: ('in', ' ') -> 'in '
 296: ('r', 'an') -> 'ran'
 297: ('in', 'k') -> 'ink'
 298: ('s', 'a') -> 'sa'
 299: ('d', 'ea') -> 'dea'
 300: ('an', ' ') -> 'an '


In [6]:
def bpe_encode(text: str, merges: List[Tuple[str, str]]) -> List[str]:
    """Tokenize ``text`` by replaying learned BPE merge rules.

    The text is first split with ``corpus_to_tokens``, the same pre-tokenizer
    ``fit_bpe`` uses. Every rule is then applied in the order it was learned,
    each one greedily left-to-right via ``merge_tokens``.

    Args:
        text: string to encode.
        merges: (left, right) pairs in learned order, e.g. ``fit_bpe(...)["merges"]``.

    Returns:
        The list of token strings; joining them reproduces ``text`` exactly.
    """
    # Reuse the training-time pre-tokenizer so encoding always starts from the
    # same initial units the merges were learned on.
    tokens = corpus_to_tokens(text)
    for pair in merges:
        # No rule can apply to fewer than two tokens, so the rest are no-ops.
        if len(tokens) < 2:
            break
        tokens = merge_tokens(tokens, pair)
    return tokens

test_str = "Alas, poor Yorick! I knew him, Horatio:"
encoded = bpe_encode(test_str, merges)

# Merges only concatenate adjacent tokens, so encoding must be lossless.
assert "".join(encoded) == test_str

print("initial string:")
print(test_str)
print("\nBPE encoded token seq:")
print(encoded)
print("\nNumber of tokens:", len(encoded))

initial string:
Alas, poor Yorick! I knew him, Horatio:

BPE encoded token seq:
['A', 'la', 's, ', 'po', 'or', ' ', 'Y', 'or', 'i', 'ck', '! ', 'I ', 'k', 'ne', 'w', ' him', ', ', 'H', 'or', 'at', 'i', 'o', ':']

Numer of tokens: 23
