In [1]:
import regex 
text = '''Once'''
PAT = r"""<\|endoftext\|>|\s+(?=<\|endoftext\|>)|'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
code = text.encode("utf-8")

In [2]:
import os
from typing import BinaryIO


def find_chunk_boundaries(
    file: BinaryIO,
    desired_num_chunks: int,
    split_special_token: bytes,
) -> list[int]:
    """
    Chunk the file into parts that can be counted independently.
    May return fewer chunks if the boundaries end up overlapping.
    """
    assert isinstance(split_special_token, bytes), "Must represent special token as a bytestring"

    # Get total file size in bytes
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = file_size // desired_num_chunks

    # Initial guesses for chunk boundary locations, uniformly spaced
    # Chunks start on previous index, don't include last index
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

    mini_chunk_size = 4096  # Read ahead by 4k bytes at a time
    
    for bi in range(1, len(chunk_boundaries) - 1):
        initial_position = chunk_boundaries[bi]
        file.seek(initial_position)  # Start at boundary guess
        while True:
            
            mini_chunk = file.read(mini_chunk_size)  # Read a mini chunk
            # If EOF, this boundary should be at the end of the file
            if mini_chunk == b"":
                chunk_boundaries[bi] = file_size
                break

            # Find the special token in the mini chunk
            found_at = mini_chunk.find(split_special_token)

            if found_at != -1:
                chunk_boundaries[bi] = initial_position + found_at
                break
            initial_position += mini_chunk_size

    # Make sure all boundaries are unique, but might be fewer than desired_num_chunks
    return sorted(set(chunk_boundaries))

with open("tests/fixtures/tinystories_sample.txt", "rb") as f:
    num_processes = 2
    boundaries = find_chunk_boundaries(f, num_processes, b"<|endoftext|>")

    # The following is a serial implementation, but you can parallelize this
    # by sending each start/end pair to a set of processes.
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        f.seek(start)
        chunk = f.read(end - start).decode("utf-8", errors="ignore")
        
        
                # if n != 0:
                #     prior = (token[n-1],token[n])
                #     counter_buffer[prior] -= 1
                #     new_pair = (token[n-1],merged_pair)
                #     pair_trace[prior].remove(idx)
                    
                #     count = counter_buffer.get(new_pair,0) + 1
                #     counter_buffer[new_pair] = count
                #     new_trace = pair_trace.setdefault(new_pair, [])
                #     new_trace.append(idx)
                    
                # if n != length - 2:
                #     post = (token[n+1],token[n+2])
                #     counter_buffer[post] -= 1
                #     new_pair = (merged_pair,token[n+2])
                #     count = counter_buffer.get(new_pair,0) + 1
                #     counter_buffer[new_pair] = count
                #     pair_trace[post].remove(idx)
                #     new_trace = pair_trace.setdefault(new_pair, [])
                #     new_trace.append(idx)
       

In [16]:
from typing import List, Union, Tuple, OrderedDict,Dict
from tqdm import tqdm
from collections import OrderedDict
from collections import Counter
import itertools
import time

def get_token_pair(token:list)->List[Tuple]:
    """
    return pairs of given token
    token: ["12","46","22"]
    output: [("12","46"),("46","22")]
    """
    return list(zip(token[:-1],token[1:]))


def finalize_counter_buffer(counter_buffer,merged_pair,pair_trace,single_token_buffer):
    
    trace = set(pair_trace[merged_pair])
    a,b = merged_pair
    
    for idx in trace:
        token = single_token_buffer[idx]
        length = len(token)
        new_token = []
        n = 0
        prior_token_is_merged_token = False
        while n < length - 1:
            if token[n] == a and token[n+1] == b:
                
                new_token.append(merged_pair)
                # 修改计数 与 路径 
                if not prior_token_is_merged_token:
                    if n != 0:
                        prior = (token[n-1],token[n])
                        counter_buffer[prior] -= 1
                        new_pair = (token[n-1],merged_pair)
                        pair_trace[prior].remove(idx)
                        
                        count = counter_buffer.get(new_pair,0) + 1
                        counter_buffer[new_pair] = count
                        new_trace = pair_trace.setdefault(new_pair, [])
                        new_trace.append(idx)
                    
                if n != length - 2:
                    post = (token[n+1],token[n+2])
                    counter_buffer[post] -= 1
                    new_pair = (merged_pair,token[n+2])
                    count = counter_buffer.get(new_pair,0) + 1
                    counter_buffer[new_pair] = count
                    pair_trace[post].remove(idx)
                    new_trace = pair_trace.setdefault(new_pair, [])
                    new_trace.append(idx)
                n += 1 
                prior_token_is_merged_token = True
            else:
                new_token.append(token[n])
                prior_token_is_merged_token = False
            n += 1
        single_token_buffer[idx] = new_token
    
    counter_buffer[merged_pair] = 0




class BPE_trainier():
    def __init__(self, PAT = None):
        if PAT == None:
            self.PAT = r"""<\|endoftext\|>|\s+(?=<\|endoftext\|>)|'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        else:
            self.PAT = PAT
        


    def train(self, file_path, vocab_size ,special_tokens = ["<|endoftext|>"]):
        
        merge = []
        start_time = time.time()
        # 读取并切分tokens
        with open(file_path, "r") as f:
            data = f.read()
        tokens:List[str] = regex.findall(self.PAT, data, regex.UNICODE)
        print("读取数据时间:", time.time()-start_time)
        start_time = time.time()
        # 初始化vocab
        # token_buffer:List[List[Tuple[int]]] = [] # 用来存储pair
        single_token_buffer:List[List[int]] = []
        
        counter_buffer:Dict[Tuple,int] = {} # pair的数量们

        pair_trace:Dict[List[int]] = {} # 记录某一个pair都在哪些token中出现过
        
        length = len(tokens)
        n = 0
        while n < length:
            token = tokens[n]
            
        
            token = list(token.encode("utf-8"))
            single_token_buffer.append(token) #得到单bytes的列表. 之后用来统计数量
            if len(token) >= 2:
                token_pair = get_token_pair(token)
                # token_buffer.append(token_pair) 
                for pair in token_pair:
                    tmp_count = counter_buffer.get(pair,0) + 1
                    counter_buffer[pair] = tmp_count
                    trace = pair_trace.setdefault(pair, [])
                    trace.append(n)
            n += 1
        
        all_bytes = itertools.chain.from_iterable(single_token_buffer)
        
        counter = Counter(all_bytes)
        
        vocab = {value:key for (key,value) in enumerate(counter.keys())}
        vocab_count = len(vocab)
        for special_token in special_tokens:
            vocab[special_token] = vocab_count + 1
            vocab_count += 1
        print("初始化时间:", time.time()-start_time)
        start_time = time.time()
        

        # 开始byte pair合并操作
        while vocab_count < vocab_size:
            
            
            vocab_count += 1

            max_count = max(counter_buffer.values())
            
            pairs = [pair for pair,count in counter_buffer.items() if count == max_count]
            
            merged_pair = pairs[0]
            merge.append(merged_pair)
            finalize_counter_buffer(
                counter_buffer = counter_buffer,
                merged_pair = merged_pair,
                pair_trace = pair_trace,
                single_token_buffer = single_token_buffer
            )
            counter[merged_pair] = max_count
            vocab[merged_pair] = vocab_count
            
            for token in merged_pair:
                counter[token] -= max_count
                if counter[token] == 0:
                    del vocab[token]
                    vocab_count -= 1
        print("循环耗时", time.time()-start_time)
        return merge, vocab



                

                

In [17]:
trainier = BPE_trainier()
import time
start_time = time.time()
merge,vocab = trainier.train(file_path="tests/fixtures/corpus.en",
               vocab_size=500)


print(time.time()-start_time)

读取数据时间: 0.04183340072631836
初始化时间: 0.19279718399047852
循环耗时 0.773115873336792
1.0119178295135498


In [13]:
vocab

{105: 0,
 114: 1,
 111: 2,
 110: 3,
 32: 4,
 99: 5,
 101: 6,
 109: 7,
 116: 8,
 115: 9,
 97: 10,
 100: 11,
 121: 12,
 102: 13,
 117: 14,
 112: 15,
 119: 16,
 104: 17,
 108: 18,
 98: 19,
 107: 20,
 103: 21,
 41: 23,
 46: 24,
 10: 25,
 44: 26,
 118: 27,
 80: 28,
 79: 29,
 66: 30,
 120: 31,
 78: 32,
 84: 33,
 65: 34,
 76: 35,
 87: 36,
 69: 37,
 68: 38,
 73: 39,
 77: 40,
 83: 41,
 63: 43,
 86: 44,
 53: 45,
 35: 46,
 45: 48,
 71: 49,
 48: 50,
 56: 51,
 49: 52,
 51: 53,
 50: 54,
 52: 55,
 57: 56,
 70: 57,
 38: 58,
 59: 59,
 85: 60,
 113: 61,
 67: 62,
 54: 63,
 47: 64,
 55: 65,
 81: 66,
 88: 67,
 194: 68,
 174: 69,
 72: 70,
 89: 71,
 173: 72,
 106: 73,
 42: 74,
 169: 76,
 195: 77,
 160: 78,
 122: 79,
 75: 81,
 82: 82,
 188: 83,
 161: 84,
 90: 85,
 95: 86,
 36: 87,
 37: 88,
 177: 89,
 147: 90,
 148: 91,
 61: 92,
 181: 93,
 164: 94,
 226: 95,
 130: 96,
 172: 97,
 182: 98,
 165: 99,
 151: 100,
 132: 101,
 162: 102,
 159: 103,
 239: 104,
 191: 105,
 189: 106,
 43: 107,
 '<|endoftext|>': 109,
 (32

In [5]:
s = [2,3,3,3]
s.remove(3)
s

[2, 3, 3]

In [6]:
vocab = {"a":[3],"b":3}

s = vocab.setdefault("a")
s.append(1)
vocab

{'a': [3, 1], 'b': 3}

In [11]:
import torch
from torch import nn
class DeepseekV2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )
        self.max_seq_len_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        print(emb[:2,:])
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
model = DeepseekV2RotaryEmbedding(dim=10, device="cuda")
x = torch.randn((1,1,2,10), device="cuda")
model(x,seq_len = 2)

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04, 1.0000e+00,
         1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]], device='cuda:0')
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04, 1.0000e+00,
         1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]], device='cuda:0')


(tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000],
         [0.5403, 0.9875, 0.9997, 1.0000, 1.0000, 0.5403, 0.9875, 0.9997, 1.0000,
          1.0000]], device='cuda:0'),
 tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [8.4147e-01, 1.5783e-01, 2.5116e-02, 3.9811e-03, 6.3096e-04, 8.4147e-01,
          1.5783e-01, 2.5116e-02, 3.9811e-03, 6.3096e-04]], device='cuda:0'))