In [73]:
import re
from collections import Counter
from typing import List, Tuple, Counter as _Counter
import sys
sys.path.append("./cs336_basics")
from collections import defaultdict, Counter
from pretokenization_example import find_chunk_boundaries

In [74]:
text = "low low low low low lower lower widest widest widest newest newest newest newest newest newest<|endoftext|>"
special = "<|endoftext|>"

def train_bpe_from_text(
    text: str,
    special: str = "<|endoftext|>",
    num_merges: int = 6
) -> Tuple[_Counter[Tuple[bytes, ...]], List[Tuple[bytes, bytes]]]:
    """
    在内存字符串上运行简易 BPE 训练。

    Args:
        text: 包含 special token 的原始文本。
        special: 用于拆分的特殊终止符，始终作为独立 token 处理。
        num_merges: 执行多少次 BPE 合并。

    Returns:
        vocab: Counter{ tuple(bytes,…): freq }
        merges: List of (bytes, bytes) merge operations in order.
    """
    # 1) 按 special 拆分，并保留分隔符自身
    parts = re.split(f"({re.escape(special)})", text)

    # 2) 构建“词”列表：special 独立，其它按空白拆
    pre_text: List[str] = []
    for part in parts:
        if not part:
            continue
        if part == special:
            pre_text.append(part)
        else:
            pre_text.extend(part.split())

    # 3) 将每个词编码为 tuple(bytes,…)
    def word_to_bytes_tokens(w: str) -> Tuple[bytes, ...]:
        return tuple(w.encode("utf-8"))

    byte_tokens = [word_to_bytes_tokens(w) for w in pre_text]

    # 4) 初始化 vocab counter
    vocab: _Counter[Tuple[bytes, ...]] = Counter(byte_tokens)

    # 5) 辅助函数：统计所有相邻 byte-pair 频次
    def get_pairs(v: _Counter[Tuple[bytes, ...]]) -> Counter:
        pairs = Counter()
        for seq, freq in v.items():
            for i in range(len(seq) - 1):
                pairs[(seq[i], seq[i+1])] += freq
        return pairs

    # 6) 辅助函数：在 vocab 中合并指定 pair
    def merge_vocab(pair: Tuple[bytes, bytes], v: _Counter[Tuple[bytes, ...]]) -> _Counter:
        a, b = pair
        merged = a + b
        new_v = Counter()
        for seq, freq in v.items():
            i = 0
            new_seq = []
            while i < len(seq):
                if i < len(seq)-1 and (seq[i], seq[i+1]) == pair:
                    new_seq.append(merged)
                    i += 2
                else:
                    new_seq.append(seq[i])
                    i += 1
            new_v[tuple(new_seq)] += freq
        return new_v

    # 7) 执行 num_merges 次 BPE 合并
    merges: List[Tuple[bytes, bytes]] = []
    for _ in range(num_merges):
        pairs = get_pairs(vocab)
        if not pairs:
            break
        max_freq = max(pairs.values())
        # 频次最高、按字节序选最大的 pair
        best = max(p for p, f in pairs.items() if f == max_freq)
        merges.append(best)
        vocab = merge_vocab(best, vocab)

    return vocab, merges

train_bpe_from_text(
    text,
    special,
    num_merges=6
)

(Counter({(211, 451): 6,
          (338,): 5,
          (119, 105, 100, 332): 3,
          (338, 101, 114): 2,
          (60, 124, 101, 110, 100, 111, 102, 116, 101, 120, 116, 124, 62): 1}),
 [(115, 116), (101, 231), (111, 119), (108, 230), (119, 332), (110, 101)])

In [75]:
import os
from typing import BinaryIO

def find_chunk_boundaries(
    file: BinaryIO, 
    desired_num_chunks: int, 
    split_special_token: bytes
) -> list[int]:
    """
    Chunk the file into parts that can be counted independently.
    May return fewer chunks if the boundaries end up overlapping.
    """
    assert isinstance(split_special_token, bytes), (
        "Must represent special token as a bytestring"
    )

    # Get total file size in bytes
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = file_size // desired_num_chunks

    # Initial guesses for chunk boundary locations, uniformly spaced
    # Chunks start on previous index, don't include last index
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

    mini_chunk_size = 4096  # Read ahead by 4k bytes at a time

    for bi in range(1, len(chunk_boundaries) - 1):
        initial_position = chunk_boundaries[bi]
        file.seek(initial_position)  # Start at boundary guess
        while True:
            mini_chunk = file.read(mini_chunk_size)  # Read a mini chunk

            # If EOF, this boundary should be at the end of the file
            if mini_chunk == b"":
                chunk_boundaries[bi] = file_size
                break

            # Find the special token in the mini chunk
            found_at = mini_chunk.find(split_special_token)
            if found_at != -1:
                chunk_boundaries[bi] = initial_position + found_at
                break
            initial_position += mini_chunk_size

    # Make sure all boundaries are unique, but might be fewer than desired_num_chunks
    return sorted(set(chunk_boundaries))

In [76]:
#check cpu number
import multiprocessing
cpu_count = multiprocessing.cpu_count()

In [None]:
with open("../data/TinyStoriesV2-GPT4-valid.txt", "rb") as f:
    boundaries = find_chunk_boundaries(
        f, cpu_count, "<|endoftext|>".encode("utf-8"))
        
    # The following is a serial implementation, but you can parallelize this 
    # by sending each start/end pair to a set of processes.
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        f.seek(start)
        chunk = f.read(end - start).decode("utf-8", errors="ignore")
        vocab, merges = train_bpe_from_text(
            chunk,
            special="<|endoftext|>",
            num_merges=6
        )

0 937806
937806 1875363
1875363 2813541
2813541 3751213
3751213 4688540
4688540 5625758
5625758 6563458
6563458 7501211
7501211 8438541
8438541 9376342
9376342 10314231
10314231 11252559
11252559 12189328
12189328 13126706
13126706 14064356
14064356 15002188
15002188 15939582
15939582 16877372
16877372 17815181
17815181 18752209
18752209 19690315
19690315 20627655
20627655 21565010
21565010 22502601


In [80]:
from llm.bpe import train_bpe

train_bpe(input_path="../data/TinyStoriesV2-GPT4-train.txt",vocab_size=10000)


(Counter({(b'.',): 44108480,
          (b',',): 23884516,
          (b'the',): 20829509,
          (b'and',): 19479154,
          (b'a',): 15064461,
          (b'to',): 14903918,
          (b'"',): 11964631,
          (b'was',): 10593249,
          (b'The',): 5928078,
          (b'They',): 5651745,
          (b'it',): 5140524,
          (b'He',): 4968347,
          (b'said',): 4370397,
          (b'Tim',): 4216979,
          (b'day',): 4216782,
          (b'with',): 4208805,
          (b"'",): 4032158,
          (b'She',): 3946351,
          (b'!',): 3860155,
          (b'in',): 3848606,
          (b'her',): 3840818,
          (b'his',): 3781413,
          (b'big',): 3430394,
          (b'he',): 3234807,
          (b'they',): 2948584,
          (b'had',): 2875874,
          (b'you',): 2840783,
          (b'I',): 2728804,
          (b'<|endoftext|>',): 2717699,
          (b'not',): 2691190,
          (b'on',): 2550471,
          (b'of',): 2541805,
          (b'happy',): 2529226,
       

In [1]:
from collections.abc import Callable, Iterable
from typing import Optional
import torch
import math

class SGD(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = dict(lr=lr)
        super().__init__(params, defaults)

    def step(self, closure: Optional[Callable] = None):
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                t = state.get('t', 0)
                grad = p.grad.data
                p.data -= lr / math.sqrt(t + 1) * grad
                state['t'] = t + 1
        return loss

In [7]:
weights = torch.nn.Parameter(5 * torch.randn((10, 10)))
opt = SGD([weights], lr=1e2)
for t in range(10):
    opt.zero_grad() # Reset the gradients for all learnable parameters.
    loss = (weights**2).mean() # Compute a scalar loss value.
    print(loss.cpu().item())
    loss.backward() # Run backward pass, which computes gradients.
    opt.step() # Run optimizer step.

26.740989685058594
26.740982055664062
4.588027000427246
0.10980181396007538
1.168943451344663e-16
1.3028599227883843e-18
4.3871882543965306e-20
2.613477953154858e-21
2.2420103453493473e-22
2.4911227111251733e-23
