In [None]:
from collections import defaultdict
import regex as re
# Strings that must never be split by pre-tokenization.
special_tokens = ["<|endoftext|>"]

# Toy input used by the walkthrough below.
text = "the cat ate"

# Token id -> token bytes; the id is simply the position in this list.
vocab = dict(enumerate([
    b' ', b'a', b'c', b'e', b'h', b't',
    b'th', b' c', b' a', b'the', b' at',
]))

# Merge rules in priority order (earlier entries are applied first).
merges = [
    (b't', b'h'),
    (b' ', b'c'),
    (b' ', b'a'),
    (b'th', b'e'),
    (b' a', b't'),
]


def word_to_bytes_tuple(word: str):
    """Split ``word`` into a tuple of single-byte ``bytes`` objects (UTF-8)."""
    encoded = word.encode("utf-8")
    # Slicing a bytes object one position at a time yields 1-byte bytes values.
    return tuple(encoded[pos:pos + 1] for pos in range(len(encoded)))
# Pre-tokenization pattern; the \p{...} classes need the `regex` module,
# which this cell imports under the name `re`.
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

# Cut the text at every special token so no pre-token spans one of them.
chunks = re.split("|".join(re.escape(special) for special in special_tokens), text)

# One tuple of single-byte tokens per pre-token, in order of appearance.
tokens = [
    word_to_bytes_tuple(match.group(0))
    for chunk in chunks
    for match in re.finditer(PAT, chunk)
]

print(tokens)


def bpe_encode(tokens: tuple[bytes, ...], merges: list[tuple[bytes, bytes]], vocab: dict) -> list[int]:
    """Apply BPE merges to one pre-token and map the result to vocab ids.

    Args:
        tokens: The pre-token as a sequence of ``bytes`` tokens (initially
            one byte each, e.g. ``(b't', b'h', b'e')``).
        merges: Merge rules in priority order; a lower index wins.
        vocab: Mapping from token id to token bytes.

    Returns:
        The token ids of the fully merged sequence.

    Raises:
        KeyError: If a token left after merging is not a value of ``vocab``.
    """
    vocab_reverse = {v: k for k, v in vocab.items()}

    # Rank lookup table instead of repeated ``merges.index`` scans.
    # ``setdefault`` keeps the first occurrence of a duplicated rule, which is
    # what ``list.index`` would have returned.
    merge_rank = {}
    for rank, pair in enumerate(merges):
        merge_rank.setdefault(pair, rank)

    print("to encode: ", tokens)
    while True:
        # Distinct adjacent pairs that have a merge rule.
        candidates = {pair for pair in zip(tokens, tokens[1:]) if pair in merge_rank}
        if not candidates:
            break
        # Highest priority == smallest rank.
        best_pair = min(candidates, key=merge_rank.__getitem__)

        # Merge every non-overlapping occurrence of best_pair, left to right.
        new_tokens = []
        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == best_pair:
                new_tokens.append(tokens[i] + tokens[i + 1])
                i += 2
            else:
                new_tokens.append(tokens[i])
                i += 1
        tokens = new_tokens
    print("after encode: ", tokens)

    # Translate the merged tokens into vocab ids.
    return [vocab_reverse[t] for t in tokens]

# Encode each pre-token on its own and print the resulting token ids.
# (To try a single pre-token: bpe_encode(tokens[0], merges, vocab).)
for token in tokens:
    print(bpe_encode(token, merges, vocab))


In [None]:
import regex as re
from collections import defaultdict
from typing import Iterable, Iterator

def word_to_bytes_tuple(word: str):
    """Return the UTF-8 encoding of ``word`` as a tuple of 1-byte ``bytes``."""
    raw = word.encode("utf-8")
    single_bytes = []
    for value in raw:
        # Iterating bytes yields ints; wrap each one back into a bytes object.
        single_bytes.append(bytes((value,)))
    return tuple(single_bytes)

class BPE_Tokenizer:
    """Byte-pair-encoding tokenizer driven by a fixed vocab and merge list.

    Attributes:
        vocab: Mapping from token id to token bytes.
        merges: Merge rules in priority order (lower index wins).
        special_tokens: Strings that are never split by pre-tokenization.
        vocab_reverse: Mapping from token bytes back to token id.
    """

    # Pre-tokenization pattern. The \p{...} classes require `re` to be the
    # third-party `regex` module, which is how this file imports it.
    PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

    def __init__(self, vocab, merges, special_tokens=None) -> None:
        """Store the vocab/merges and build the lookup tables used by encode.

        Args:
            vocab: Mapping from token id to token bytes.
            merges: Ordered list of ``(bytes, bytes)`` merge rules.
            special_tokens: Optional list of special-token strings.
        """
        self.vocab = vocab
        self.merges = merges
        self.special_tokens = special_tokens or []

        self.vocab_reverse = {token: token_id for token_id, token in vocab.items()}

        # pair -> priority. ``setdefault`` keeps the first occurrence of a
        # duplicated rule, matching what ``list.index`` would return.
        self._merge_ranks = {}
        for rank, pair in enumerate(merges or []):
            self._merge_ranks.setdefault(pair, rank)

    @classmethod
    def from_files(cls, vocab_filepath, merges_filepath, special_tokens=None):
        """Construct a tokenizer from a serialized vocabulary and merge list.

        Args:
            vocab_filepath: Path to the serialized vocabulary (str).
            merges_filepath: Path to the serialized merges (str).
            special_tokens: Optional list of special-token strings.

        Raises:
            NotImplementedError: Always. The on-disk format written by the
                BPE training code is not defined in this notebook yet.
        """
        # TODO: implement once the serialization format is settled. Raising is
        # safer than silently returning None from an alternate constructor.
        raise NotImplementedError(
            "BPE_Tokenizer.from_files is not implemented yet"
        )

    def _apply_merges(self, tokens: Iterable[bytes]) -> tuple[bytes, ...]:
        """Repeatedly apply the highest-priority applicable merge to ``tokens``.

        Args:
            tokens: The byte tokens of a single pre-token.

        Returns:
            The merged tokens, always as a tuple.
        """
        tokens = list(tokens)
        ranks = self._merge_ranks
        if not ranks:
            return tuple(tokens)

        while len(tokens) > 1:
            # Distinct adjacent pairs that have a merge rule.
            candidates = {pair for pair in zip(tokens, tokens[1:]) if pair in ranks}
            if not candidates:
                break
            # Highest priority == smallest rank.
            best_pair = min(candidates, key=ranks.__getitem__)

            # Merge every non-overlapping occurrence, scanning left to right.
            merged = []
            i = 0
            while i < len(tokens):
                if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == best_pair:
                    merged.append(tokens[i] + tokens[i + 1])
                    i += 2
                else:
                    merged.append(tokens[i])
                    i += 1
            tokens = merged
        return tuple(tokens)

    def _split_on_special_tokens(self, text):
        """Split ``text`` into ``(part, is_special)`` pieces, in order.

        Empty pieces are dropped. With no special tokens configured the text
        is returned as one ordinary piece; splitting on an empty pattern would
        otherwise cut the text at every character position.
        """
        specials = [tok for tok in self.special_tokens if tok]
        if not specials:
            return [(text, False)] if text else []

        # Longest first, so overlapping special tokens match greedily.
        ordered = sorted(specials, key=len, reverse=True)
        # The capturing group makes split() keep the matched special tokens.
        pattern = "(" + "|".join(map(re.escape, ordered)) + ")"
        special_set = set(specials)
        return [(part, part in special_set) for part in re.split(pattern, text) if part]

    def encode(self, text) -> list[int]:
        """Encode ``text`` into a list of token ids.

        Special tokens that have an entry in the vocab are emitted as their
        own id; special tokens missing from the vocab only act as split
        points and produce no id.

        Raises:
            KeyError: If an ordinary token left after merging is not in the
                vocab.
        """
        token_ids = []
        for part, is_special in self._split_on_special_tokens(text):
            if is_special:
                special_id = self.vocab_reverse.get(part.encode("utf-8"))
                if special_id is not None:
                    token_ids.append(special_id)
                continue
            for match in re.finditer(self.PAT, part):
                word_tokens = word_to_bytes_tuple(match.group(0))
                token_ids.extend(
                    self.vocab_reverse[tok] for tok in self._apply_merges(word_tokens)
                )
        return token_ids

    def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:
        """Lazily yield token ids for each string of ``iterable`` in turn.

        Only one element of ``iterable`` (e.g. one line of a file handle) is
        encoded at a time, so the whole input never has to be in memory.
        """
        for text in iterable:
            # Flatten: yield individual ids, not one list per input string.
            yield from self.encode(text)

    def decode(self, ids: list[int]) -> str:
        """Decode a sequence of token ids into text.

        The token bytes are concatenated before decoding so that a character
        whose UTF-8 bytes are spread over several tokens decodes correctly.
        Invalid byte sequences become U+FFFD instead of raising.

        Raises:
            KeyError: If an id is not present in the vocab.
        """
        raw = b"".join(self.vocab[token_id] for token_id in ids)
        return raw.decode("utf-8", errors="replace")
    

# Round-trip demo: encode a toy sentence with a tiny hand-written vocab and
# merge list, then decode the ids back into text.
text = "the cat ate"
# Token id -> token bytes.
vocab = {0: b' ', 1: b'a', 2:b'c', 3: b'e', 4: b'h', 5: b't', 6: b'th', 7: b' c', 8: b' a', 9: b'the', 10: b' at'}
# Merge rules in priority order (earlier entries win).
merges = [(b't', b'h'), (b' ', b'c'), (b' ', b'a'), (b'th', b'e'), (b' a', b't')]
# NOTE: "<|endoftext|>" is registered as a special token but has no entry in
# this vocab; the demo text does not contain it.
tokenizer = BPE_Tokenizer(vocab, merges, special_tokens=["<|endoftext|>"])

# Encode, then show the ids together with their container type.
encoded_ids = tokenizer.encode(text)
print(encoded_ids, type(encoded_ids))
# Decode the ids back and show the recovered text and its type.
decode_str = tokenizer.decode(encoded_ids)
print(decode_str, type(decode_str))


In [7]:
import torch

# Demo: apply one matrix across all pixel positions of a channels-last image
# batch using plain view/transpose calls, printing the shape after each step.
# NOTE: tensors are random and no seed is set here; only shapes are printed.
h, w = 32, 16
print(f"h: {h}, w: {w}, h*w: {h*w}")
channels_last = torch.randn(64, h, w, 3) # (batch, height, width, channel)
# Square (h*w, h*w) matrix applied along the flattened pixel axis below.
B = torch.randn(h*w, h*w)
## Rearrange an image tensor for mixing across all pixels
# Step 1: collapse height and width into one pixel axis -> (batch, h*w, channel).
channels_last_flat = channels_last.view(
-1, channels_last.size(1) * channels_last.size(2), channels_last.size(3)
)
print(f"after view: {channels_last_flat.shape}")
# Step 2: move the pixel axis last -> (batch, channel, h*w), so the matmul
# below acts on it.
channels_first_flat = channels_last_flat.transpose(1, 2)
print(f"after transpose: {channels_first_flat.shape}")
# Step 3: right-multiply by B.T; the pixel axis of each (batch, channel) row
# is contracted with the second axis of B.
channels_first_flat_transformed = channels_first_flat @ B.T
print(f"B: {B.shape}, B.T: {B.T.shape}")
print(f"after transform: {channels_first_flat_transformed.shape}")
# Step 4: undo step 2 -> (batch, h*w, channel).
channels_last_flat_transformed = channels_first_flat_transformed.transpose(1, 2)
print(f"after transpose: {channels_last_flat_transformed.shape}")
# Step 5: undo step 1 by viewing with the original (batch, h, w, channel) shape.
channels_last_transformed = channels_last_flat_transformed.view(*channels_last.shape)
print(f"after view: {channels_last_transformed.shape}")

h: 32, w: 16, h*w: 512
after view: torch.Size([64, 512, 3])
after transpose: torch.Size([64, 3, 512])
B: torch.Size([512, 512]), B.T: torch.Size([512, 512])
after transform: torch.Size([64, 3, 512])
after transpose: torch.Size([64, 512, 3])
after view: torch.Size([64, 32, 16, 3])


In [27]:
from einops import rearrange, reduce, repeat, einsum
# Same pixel-mixing transform as the torch view/transpose cell, written with
# einops. Relies on `channels_last`, `B`, `h` and `w` from that cell, so it
# must be run after it.
print(f"channels_last shape {channels_last.shape}")

# Move channels ahead of the pixels and flatten (h, w) into one axis.
channels_first = rearrange(channels_last, 'b h w c -> b c (h w)')
print(f"channels_first shape {channels_first.shape}")

# Contract the flattened pixel axis (ppp_inin) with the second axis of B;
# the output keeps B's first axis (ppp_out), i.e. the same product as `@ B.T`.
channels_first_transformed = einsum(
    channels_first, B,
    "b c ppp_inin, ppp_out ppp_inin -> b c ppp_out"
)
print(f"channels_first_transformed shape {channels_first_transformed.shape}")
# Unflatten the pixel axis back into (h, w) and move channels last again.
# NOTE: this rebinds `channels_last_transformed`, which the torch cell also assigns.
channels_last_transformed = rearrange(
    channels_first_transformed, 
    "b c (hhh w) -> b hhh w c", hhh=h, w=w
)
print(f"channels_last_transformed shape {channels_last_transformed.shape}")

channels_last shape torch.Size([64, 32, 16, 3])
channels_first shape torch.Size([64, 3, 512])
channels_first_transformed shape torch.Size([64, 3, 512])
channels_last_transformed shape torch.Size([64, 32, 16, 3])
