In [3]:
from __future__ import annotations

import os
import regex as re
from collections.abc import Iterable
from typing import IO, Any, BinaryIO
from tqdm import tqdm

import numpy.typing as npt
import torch
from jaxtyping import Bool, Float, Int
from torch import Tensor

from concurrent.futures import ThreadPoolExecutor

In [4]:
def find_chunk_boundaries(
    file: BinaryIO,
    desired_num_chunks: int,
    split_special_token: bytes,
) -> list[int]:
    """
    Split a binary file into byte ranges that can be processed independently.

    Boundaries start out evenly spaced. Every interior boundary is then moved
    forward to the first occurrence of `split_special_token` found at or after
    its initial guess, or to the end of the file when no occurrence is found.
    Boundaries that end up identical are collapsed, so the result may describe
    fewer than `desired_num_chunks` chunks.

    Returns the sorted, de-duplicated list of byte offsets; consecutive entries
    delimit one chunk each.
    """
    assert isinstance(split_special_token, bytes), "Must represent special token as a bytestring"

    # Measure the file by seeking to its end, then rewind.
    file.seek(0, os.SEEK_END)
    total_bytes = file.tell()
    file.seek(0)

    stride = total_bytes // desired_num_chunks
    read_size = 4096  # bytes fetched per look-ahead read

    def _advance_to_token(offset: int) -> int:
        """Scan forward from `offset`; return the token's offset, or EOF if absent."""
        file.seek(offset)
        while True:
            window = file.read(read_size)
            if window == b"":
                # Ran off the end of the file without a match.
                return total_bytes
            hit = window.find(split_special_token)
            if hit != -1:
                return offset + hit
            offset += read_size

    # Uniform guesses; the first stays at 0 and the last is pinned to EOF.
    boundaries = [stride * k for k in range(desired_num_chunks + 1)]
    boundaries[-1] = total_bytes

    for idx in range(1, len(boundaries) - 1):
        boundaries[idx] = _advance_to_token(boundaries[idx])

    # Distinct offsets only, in ascending order.
    return sorted(set(boundaries))

In [37]:
import time


def _replace_pair(symbols: list[int], pair: tuple[int, int], new_id: int) -> list[int]:
    """Return a copy of `symbols` with each left-to-right occurrence of `pair` replaced by `new_id`."""
    first, second = pair
    merged = []
    last = len(symbols) - 1
    i = 0
    while i <= last:
        if i < last and symbols[i] == first and symbols[i + 1] == second:
            merged.append(new_id)
            i += 2
        else:
            merged.append(symbols[i])
            i += 1
    return merged


def train_bpe(
    input_path: str | os.PathLike,
    vocab_size: int,
    special_tokens: list[str],
    **kwargs,
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Train a byte-level BPE tokenizer on the corpus at `input_path`.

    Args:
        input_path (str | os.PathLike): Path to BPE tokenizer training data.
        vocab_size (int): Total number of items in the tokenizer's vocabulary (including special tokens).
        special_tokens (list[str]): Special tokens added to the vocabulary. Occurrences of
            these strings in the corpus are used as split points before pre-tokenization,
            so they never contribute to any merge.
        **kwargs: Accepted for interface compatibility; unused.

    Returns:
        tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
            vocab:
                Mapping from token ID to token bytes. Special tokens take the first IDs,
                followed by the 256 single-byte tokens, followed by merged tokens in
                order of creation.
            merges:
                BPE merges. Each list item is a tuple of bytes (<token1>, <token2>),
                representing that <token1> was merged with <token2>.
                Merges are ordered by order of creation.

    Training stops early, with fewer than `vocab_size` entries, when the corpus has no
    adjacent pair left to merge. Cumulative timings of the two phases of the merge loop
    are printed before returning.
    """
    # 1. Pre-tokenization
    PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

    # Special tokens are literals, not regexes: '<|endoftext|>' contains '|', which would
    # otherwise be read as alternation. Longest first so a token that is a prefix of
    # another cannot shadow it; empty strings are dropped because they match everywhere.
    split_pattern = "|".join(
        re.escape(token) for token in sorted(special_tokens, key=len, reverse=True) if token
    )

    # Occurrence count of every distinct pre-token, keyed by its UTF-8 bytes.
    word_counts: dict[bytes, int] = {}
    with open(input_path, "rb") as f:
        num_process = 10
        boundaries = find_chunk_boundaries(f, num_process, b'<|endoftext|>')
        # Chunks are processed one after another; chunking only bounds how much is read at once.
        for start, end in zip(boundaries[:-1], boundaries[1:]):
            f.seek(start)
            chunk = f.read(end - start).decode('utf-8', errors='ignore')
            # Without any special token there is nothing to split on.
            parts = re.split(split_pattern, chunk) if split_pattern else [chunk]
            for part in parts:
                for match in re.finditer(PAT, part):
                    word = match.group(0).encode('utf-8')
                    word_counts[word] = word_counts.get(word, 0) + 1

    # 2. Initial vocabulary: the 256 byte values, then the special tokens.
    vocab: dict[int, bytes] = {i: bytes([i]) for i in range(256)}
    for offset, special_token in enumerate(special_tokens):
        vocab[256 + offset] = special_token.encode('utf-8')
    merges: list[tuple[bytes, bytes]] = []
    next_id = 256 + len(special_tokens)

    time_taken1 = 0.0  # time spent maintaining pair counts
    time_taken2 = 0.0  # time spent rewriting pre-tokens after a merge

    # 3. Every pre-token starts as its byte values, which double as token IDs 0..255.
    count_start = time.time()
    symbols: dict[bytes, list[int]] = {word: list(word) for word in word_counts}
    pair_counts: dict[tuple[int, int], int] = {}
    for word, seq in symbols.items():
        freq = word_counts[word]
        for pair in zip(seq, seq[1:]):
            pair_counts[pair] = pair_counts.get(pair, 0) + freq
    time_taken1 += time.time() - count_start

    # 4. Merge loop. Stop when the vocabulary is full or no pair is left to merge.
    while next_id < vocab_size and pair_counts:
        # Most frequent pair; ties go to the lexicographically greatest pair of token bytes.
        best_pair = max(
            pair_counts,
            key=lambda pair: (pair_counts[pair], vocab[pair[0]], vocab[pair[1]]),
        )
        left, right = vocab[best_pair[0]], vocab[best_pair[1]]
        vocab[next_id] = left + right
        merges.append((left, right))

        # Rewrite every pre-token that contains the pair, remembering the old sequence.
        update_start = time.time()
        rewritten: dict[bytes, tuple[list[int], list[int]]] = {}
        for word, seq in symbols.items():
            if len(seq) < 2:
                continue
            new_seq = _replace_pair(seq, best_pair, next_id)
            if len(new_seq) != len(seq):
                rewritten[word] = (seq, new_seq)
        for word, (_, new_seq) in rewritten.items():
            symbols[word] = new_seq
        time_taken2 += time.time() - update_start

        # Only the rewritten pre-tokens can change the pair counts: retract the pairs of
        # the old sequence and add the pairs of the new one.
        count_start = time.time()
        for word, (old_seq, new_seq) in rewritten.items():
            freq = word_counts[word]
            for pair in zip(old_seq, old_seq[1:]):
                remaining = pair_counts.get(pair, 0) - freq
                if remaining == 0:
                    pair_counts.pop(pair, None)
                else:
                    pair_counts[pair] = remaining
            for pair in zip(new_seq, new_seq[1:]):
                pair_counts[pair] = pair_counts.get(pair, 0) + freq
        time_taken1 += time.time() - count_start

        next_id += 1

    print(f"Time taken for counting pairs: {time_taken1}, Time taken for updating tokens: {time_taken2}")

    # 5. Renumber so the special tokens come first, then the bytes, then the merged tokens.
    # Bounded by len(vocab) because training may have stopped before reaching vocab_size.
    num_special = len(special_tokens)
    ordered_vocab: dict[int, bytes] = {}
    for i in range(min(vocab_size, len(vocab))):
        if i < num_special:
            ordered_vocab[i] = vocab[i + 256]
        elif i < 256 + num_special:
            ordered_vocab[i] = vocab[i - num_special]
        else:
            ordered_vocab[i] = vocab[i]

    return ordered_vocab, merges

In [38]:
# Candidate training corpora: file_path1 is a machine-specific absolute path,
# file_path2 is a fixture path relative to the current working directory.
file_path1 = "/home/std10/extend/TinyStoriesV2-GPT4-valid.txt"
file_path2 = "./tests/fixtures/corpus.en"
# Train on file_path2 with a 500-entry vocabulary and '<|endoftext|>' as the only
# special token; file_path1 is not used in this cell.
vocab, merges = train_bpe(file_path2, vocab_size=500, special_tokens=['<|endoftext|>'])


Time taken for counting pairs: 0.04724431037902832, Time taken for updating tokens: 0.7574045658111572


In [39]:
# Find where our merges differ from the merges produced by train-bpe-reference.py.
with open("./tests/fixtures/train-bpe-reference-merges.txt", "rb") as f:
    reference_merge = []
    for line in f:
        parts = line.strip().split(b" ")
        # Blank or malformed lines do not describe a merge.
        if len(parts) < 2:
            continue
        reference_merge.append((parts[0], parts[1]))

# zip() stops at the shorter list, so report a length mismatch explicitly
# instead of running off the end of the reference.
if len(merges) != len(reference_merge):
    print(f"Merge count differs: computed {len(merges)}, reference {len(reference_merge)}")

for i, (computed, reference) in enumerate(zip(merges, reference_merge)):
    # The reference file is split on spaces above, so merges whose tokens
    # contain a space cannot be compared and are skipped.
    if b' ' in computed[0] or b' ' in computed[1]:
        continue
    if computed[:2] != reference:

        print(f"Difference at index {i}:")
        print(f"  Computed: {computed}")
        print(f"  Reference: {reference}")

In [43]:
class Tokenizer():
    def __init__(self, 
                 vocab: dict[int, bytes], 
                 merges: list[tuple[bytes, bytes]], 
                 special_tokens: list[str]=None):
        self.vocab = vocab
        self.merges = merges
        self.special_tokens = special_tokens if special_tokens is not None else []
        self.token_to_id = {token: idx for idx, token in vocab.items()}
    
    def from_files(cls, 
                   vocab_filepath: str, 
                   merges_filepath: str,
                   special_tokens: list[bytes]=None) -> Tokenizer:
        vocab: dict[int, bytes] = {}
        merges: list[tuple[bytes, bytes]] = []
        with open(vocab_filepath, 'rb') as vf:
            for line in vf:
                token = line.rstrip(b'\n')
                vocab[len(vocab)] = token
        with open(merges_filepath, 'rb') as mf:
            for line in mf:
                token1, token2 = line.rstrip(b'\n').split(b' ')
                merges.append((token1, token2))
        return cls(vocab, merges, special_tokens)
    
    def encode(self, text: str) -> list[int]:
        parts = re.split("|".join(self.special_tokens), text)
        # save all special tokens_ids
        
        whole_token_ids = []
        for part in parts:
            PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
            tokens = re.finditer(PAT, part)
            token_ids = []
            for match in tokens:
                token_bytes = match.group(0).encode('utf-8')
                token_id = self.token_to_id.get(token_bytes, None)
                if token_id is not None:
                    token_ids.append(token_id)
                else:
                    # Apply BPE merges
                    sub_tokens = [token_bytes]
                    for merge in self.merges:
                        new_sub_tokens = []
                        i = 0
                        while i < len(sub_tokens):
                            if i < len(sub_tokens) - 1 and (sub_tokens[i], sub_tokens[i+1]) == merge:
                                new_sub_tokens.append(self.vocab[len(self.vocab) + self.merges.index(merge)])
                                i += 2
                            else:
                                new_sub_tokens.append(sub_tokens[i])
                                i += 1
                        sub_tokens = new_sub_tokens
                    for sub_token in sub_tokens:
                        sub_token_id = self.token_to_id.get(sub_token, None)
                        if sub_token_id is not None:
                            token_ids.append(sub_token_id)
                        else:
                            raise ValueError(f"Sub-token {sub_token} not found in vocabulary.")
            whole_token_ids.extend(token_ids)
        return whole_token_ids

In [6]:
import numpy as np
# Machine-specific absolute path to the tokenized training data.
data_path = '/home/std10/extend/generated_data/tokenized_data_train.npy'
# mmap_mode='r' asks NumPy to memory-map the file read-only instead of reading it into memory.
# NOTE(review): allow_pickle=True permits unpickling, which can execute arbitrary code if the
# file is not trusted — confirm it is actually needed for this file.
tokenized_data = np.load(data_path, allow_pickle=True, mmap_mode='r')