In [8]:
import os
import regex as re
import torch
from torch import nn
import torch.nn.functional as F

In [14]:
t = torch.arange(12, dtype=torch.float32).reshape((2, 2, -1))
print(t)
F.softmax(t, dim=1)

tensor([[[ 0.,  1.,  2.],
         [ 3.,  4.,  5.]],

        [[ 6.,  7.,  8.],
         [ 9., 10., 11.]]])


tensor([[[0.0474, 0.0474, 0.0474],
         [0.9526, 0.9526, 0.9526]],

        [[0.0474, 0.0474, 0.0474],
         [0.9526, 0.9526, 0.9526]]])

In [None]:
PAT = re.compile(r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")


def init_vocab(special_tokens: list[str]) -> dict[int, bytes]:
    vocab = {i: bytes([i]) for i in range(256)}

    for token in special_tokens:
        bytes_token = token.encode("utf-8")
        vocab[len(vocab)] = bytes_token
    return vocab

def word2bytes(word: str) -> tuple[bytes, ...]:
    a = list(word.encode("utf-8"))
    return tuple(bytes([b]) for b in a)

In [None]:
def train_bpe(
        input_path: str | os.PathLike,
        vocab_size: int,
        special_tokens: list[str],
        **kwargs,
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Given the path to an input corpus, run train a BPE tokenizer and
    output its vocabulary and merges.

    Args:
        input_path (str | os.PathLike): Path to BPE tokenizer training data.
        vocab_size (int): Total number of items in the tokenizer's vocabulary (including special tokens).
        special_tokens (list[str]): A list of string special tokens to be added to the tokenizer vocabulary.
            These strings will never be split into multiple tokens, and will always be
            kept as a single token. If these special tokens occur in the `input_path`,
            they are treated as any other string.

    Returns:
        tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
            vocab:
                The trained tokenizer vocabulary, a mapping from int (token ID in the vocabulary)
                to bytes (token bytes)
            merges:
                BPE merges. Each list item is a tuple of bytes (<token1>, <token2>),
                representing that <token1> was merged with <token2>.
                Merges are ordered by order of creation.
    """
    # 1. Vocab
    # dict[int, bytes]
    vocab: dict[int, bytes] = init_vocab(special_tokens)

    # dict[tuple[bytes], int]
    pre_token_counts: dict[tuple[bytes, ...], int] = {}

    # 2. Pre-tokenization
    with open(input_path, "rb") as f:
        num_processes = 4
        boundaries = find_chunk_boundaries(f, num_processes, b"<|endoftext|>")

        for start, end in zip(boundaries[:-1], boundaries[1:]):
            f.seek(start)
            chunk = f.read(end - start).decode("utf-8", errors="ignore")

            # remove special tokens
            separator = "|".join([re.escape(token) for token in special_tokens])
            separator_pat = re.compile(separator)
            sub_chunks = separator_pat.split(chunk)
            for sub_chunk in sub_chunks:
                print(sub_chunk)

            # Run pre-tokenization on your chunk and store the counts for each pre-token
            for sub_chunk in sub_chunks:
                for match in PAT.finditer(sub_chunk):
                    token = match.group()
                    if token in special_tokens:
                        continue
                    token_tuple = word2bytes(token)
                    print(token, token_tuple)
                    if len(token_tuple) < 2:
                        continue
                    pre_token_counts[token_tuple] = pre_token_counts.get(token_tuple, 0) + 1
    print("pre_token_counts:{}".format(pre_token_counts))

    # 3. Merges
    token_pairs_counts: dict[tuple[bytes, bytes], int] = {}
    merges: list[tuple[bytes, bytes]] = []
    for pre_token, count in pre_token_counts.items():
        for pair in zip(pre_token[:-1], pre_token[1:]):
            # (bytes, bytes)
            token_pairs_counts[pair] = token_pairs_counts.get(pair, 0) + count

    # 4. Merges
    while len(vocab) < vocab_size:
        # get the most frequent pair, if there are multiple pairs with the same frequency, choose the lexicographically greater pair
        sorted_token_pairs_counts = sorted(token_pairs_counts.items(), key=lambda x: (x[1], x[0]), reverse=True)
        print(sorted_token_pairs_counts)
        most_frequent_pair = sorted_token_pairs_counts[0][0]

        merges.append(most_frequent_pair)
        merged_token = most_frequent_pair[0] + most_frequent_pair[1]
        vocab[len(vocab)] = merged_token
        token_pairs_counts.pop(most_frequent_pair)

        # replace the pre_token with the merged token
        for pre_token, count in list(pre_token_counts.items()):
            n = len(pre_token)
            if not any(pre_token[i] == most_frequent_pair[0] and pre_token[i + 1] == most_frequent_pair[1] for i in
                       range(n - 1)):
                continue

            # generate the new pre_token
            new_pre_token: list[bytes] = []
            i = 0
            while i < n:
                # match the most frequent pair
                if i < n - 1 and pre_token[i] == most_frequent_pair[0] and pre_token[i + 1] == most_frequent_pair[1]:
                    new_pre_token.append(merged_token)
                    # update the prefix pair count
                    if i > 0:
                        old_pre_pair = (pre_token[i - 1], pre_token[i])
                        token_pairs_counts[old_pre_pair] = token_pairs_counts[old_pre_pair] - count
                        if token_pairs_counts[old_pre_pair] == 0:
                            token_pairs_counts.pop(old_pre_pair)
                    # update the suffix pair count
                    if i < n - 2:
                        old_suffix_pair = (pre_token[i + 1], pre_token[i + 2])
                        token_pairs_counts[old_suffix_pair] = token_pairs_counts[old_suffix_pair] - count
                        if token_pairs_counts[old_suffix_pair] == 0:
                            token_pairs_counts.pop(old_suffix_pair)
                    i += 2
                else:
                    new_pre_token.append(pre_token[i])
                    i += 1

            # count the new_pre_token
            for i in range(len(new_pre_token) - 1):
                pair = (new_pre_token[i], new_pre_token[i + 1])
                token_pairs_counts[pair] = token_pairs_counts.get(pair, 0) + count

            # update the pre_token_counts with the new_pre_token
            pre_token_counts[tuple(new_pre_token)] = count
            pre_token_counts.pop(tuple(pre_token))
        print("after merge pre_token_counts:{}".format(pre_token_counts))

    return vocab, merges

## for test

In [39]:
path = "../data/test_bpe.txt"
vocab, merges = train_bpe(path, 258, ["<|endoftext|>"])

low low low low low


lower lower widest widest widest

low (b'l', b'o', b'w')
 low (b' ', b'l', b'o', b'w')
 low (b' ', b'l', b'o', b'w')
 low (b' ', b'l', b'o', b'w')
 low (b' ', b'l', b'o', b'w')

 (b'\r', b'\n')
 (b'\r',)

 (b'\n',)
lower (b'l', b'o', b'w', b'e', b'r')
 lower (b' ', b'l', b'o', b'w', b'e', b'r')
 widest (b' ', b'w', b'i', b'd', b'e', b's', b't')
 widest (b' ', b'w', b'i', b'd', b'e', b's', b't')
 widest (b' ', b'w', b'i', b'd', b'e', b's', b't')

 (b'\r', b'\n')


newest newest newest newest newest newest
 (b'\r',)

 (b'\n',)
newest (b'n', b'e', b'w', b'e', b's', b't')
 newest (b' ', b'n', b'e', b'w', b'e', b's', b't')
 newest (b' ', b'n', b'e', b'w', b'e', b's', b't')
 newest (b' ', b'n', b'e', b'w', b'e', b's', b't')
 newest (b' ', b'n', b'e', b'w', b'e', b's', b't')
 newest (b' ', b'n', b'e', b'w', b'e', b's', b't')
pre_token_counts:{(b'l', b'o', b'w'): 1, (b' ', b'l', b'o', b'w'): 4, (b'\r', b'\n'): 2, (b'l', b'o', b'w', b'e', b'r'): 1, (b' ', b'l', b'o',