In [2]:
# Install the third-party packages imported in the next cell (`datasets`, `colorama`).
# NOTE(review): `regex` is imported below as well but is not installed here —
# presumably it is already available in this environment; confirm.
!pip install datasets
!pip install colorama



In [3]:
import pickle
import sys
from collections import Counter
from itertools import chain

import colorama
import regex as re
from datasets import load_dataset


In [4]:
def most_common(seq: iter):
    """Return the most frequent element of ``seq``, or ``None`` when it is empty.

    ``seq`` may be any iterable of hashable items (it is consumed once by
    ``Counter``). Which element wins a tie is whatever
    ``Counter.most_common(1)`` reports.
    """
    tally = Counter(seq)
    if not tally:
        return None
    (winner, _occurrences), = tally.most_common(1)
    return winner


def merge(seq: list[int], pair: tuple[int, int], symbol: int) -> list[int]:
    """Replace every occurrence of ``pair`` in ``seq`` with ``symbol``.

    The scan runs left to right and does not overlap matches, so merging
    ``(1, 1)`` in ``[1, 1, 1]`` yields ``[symbol, 1]``.

    Args:
        seq: Token ids to scan; it is not modified.
        pair: Two adjacent ids to look for.
        symbol: Id written in place of each match.

    Returns:
        A new list with the replacements applied.
    """
    out = []
    total = len(seq)
    pos = 0
    while pos < total:
        # A match needs a right-hand neighbour, so the last element can only be copied.
        if pos + 1 < total and (seq[pos], seq[pos + 1]) == pair:
            out.append(symbol)
            pos += 2
            continue
        out.append(seq[pos])
        pos += 1
    return out


In [5]:
class BasicTokenizer:
    """Byte-level BPE tokenizer that learns merges over one flat byte sequence.

    The text is treated as a single sequence of UTF-8 bytes; there is no
    pre-splitting and there are no special tokens.
    """

    def __init__(self):
        # Token id -> the bytes it stands for. Ids 0..255 are the raw byte values.
        self.vocab = {b: bytes([b]) for b in range(256)}
        # Learned merges as (new_symbol, (left_id, right_id)), stored in the
        # order in which encode() has to apply them.
        self.merges = []

    def train(self, text: str, vocab_size: int) -> None:
        """Learn merges from ``text`` until the vocabulary has ``vocab_size`` entries.

        Training stops early when the sequence has no adjacent pair left.
        Calling ``train`` again continues from the merges learned so far
        instead of overwriting ``self.merges`` while ``self.vocab`` still
        holds the old entries.

        Args:
            text: Training corpus.
            vocab_size: Target number of vocabulary entries (including the
                256 byte tokens).
        """
        # Tokenise with what is already known so a repeated call resumes
        # training; on a fresh instance this is just the UTF-8 byte values.
        seq = self.encode(text)
        merges = list(self.merges)
        # Vocabulary ids are contiguous from 0, so the next free id is its size.
        new_symbol = len(self.vocab)

        while len(self.vocab) < vocab_size:
            pair = most_common(zip(seq, seq[1:]))
            if pair is None:
                break
            merges.append((new_symbol, pair))
            self.vocab[new_symbol] = self.vocab[pair[0]] + self.vocab[pair[1]]

            seq = merge(seq, pair, new_symbol)
            new_symbol += 1

        self.merges = merges

    def encode(self, text: str) -> list[int]:
        """Convert ``text`` to token ids by replaying the learned merges in order."""
        seq = list(text.encode("utf-8"))
        for symbol, pair_from_train in self.merges:
            seq = merge(seq, pair_from_train, symbol)
        return seq

    def decode(self, ids: list[int]) -> str:
        """Convert token ids back to text.

        The bytes of all tokens are concatenated before decoding, because a
        single token may hold only part of a multi-byte UTF-8 character.
        Byte sequences that are not valid UTF-8 are rendered with the Unicode
        replacement character instead of raising.

        Raises:
            KeyError: If an id is not in the vocabulary.
        """
        raw = b"".join(self.vocab[token_id] for token_id in ids)
        return raw.decode("utf-8", errors="replace")

In [6]:
def merge_with_cache(
    chunks: list[list[int]], pair: tuple[int, int], symbol: int
) -> None:
    """Apply ``merge`` to every chunk in place, computing each distinct chunk once.

    Results are memoised by chunk content for the duration of this call, so
    chunks with identical content end up referring to the same result list.

    Args:
        chunks: Token-id lists; each entry is replaced by its merged version.
        pair: Adjacent ids to replace.
        symbol: Id substituted for each occurrence of ``pair``.
    """
    memo = {}
    for position, chunk in enumerate(chunks):
        key = tuple(chunk)
        merged = memo.get(key)
        if merged is None:
            merged = merge(chunk, pair, symbol)
            memo[key] = merged
        chunks[position] = merged


In [7]:
# Regex used by Tokenizer to cut text into chunks before merges are applied.
# It uses \p{...} Unicode property classes and possessive quantifiers, so it is
# compiled with the third-party `regex` module (imported here as `re`).
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

# Marker joined between dataset examples when the training corpus is built;
# build_special_symbols() also registers it as a special token.
EOF = "<|endoftext|>"

# Names of Python built-ins that build_special_symbols() registers as special
# tokens, each both as written and with one leading space.
python_common_builtins = [
    "abs",
    "all",
    "any",
    "bin",
    "bool",
    "bytes",
    "callable",
    "chr",
    "dict",
    "dir",
    "enumerate",
    "eval",
    "exec",
    "exit",
    "filter",
    "float",
    "hash",
    "hex",
    "id",
    "input",
    "int",
    "iter",
    "len",
    "list",
    "map",
    "max",
    "min",
    "next",
    "object",
    "open",
    "ord",
    "pow",
    "print",
    "range",
    "repr",
    "reversed",
    "sorted",
    "str",
    "sum",
    "super",
    "tuple",
    "type",
    "zip",
]

# Python keywords that build_special_symbols() registers as special tokens,
# each both as written and with one leading space.
# NOTE(review): this looks like a hand-copied `keyword.kwlist` — confirm it
# matches the Python version being targeted.
python_keywords = [
    "False",
    "None",
    "True",
    "and",
    "as",
    "assert",
    "async",
    "await",
    "break",
    "class",
    "continue",
    "def",
    "del",
    "elif",
    "else",
    "except",
    "finally",
    "for",
    "from",
    "global",
    "if",
    "import",
    "in",
    "is",
    "lambda",
    "nonlocal",
    "not",
    "or",
    "pass",
    "raise",
    "return",
    "try",
    "while",
    "with",
    "yield",
]

# Operator spellings that build_special_symbols() registers as special tokens.
# Unlike the keyword and built-in lists, these are added only with a leading
# space, never bare.
# NOTE(review): "~=" is not a Python operator ("~" is unary and has no
# augmented-assignment form). Comparison operators, "@", "@=" and ":=" are not
# listed. Changing this list changes the special-token set, so any tokenizer
# trained with it would need retraining.
python_operations = [
    "+",
    "-",
    "*",
    "/",
    "%",
    "**",
    "//",
    "&",
    "|",
    "^",
    "~",
    "<<",
    ">>",
    "=",
    "+=",
    "-=",
    "*=",
    "/=",
    "%=",
    "**=",
    "//=",
    "&=",
    "|=",
    "^=",
    "~=",
    "<<=",
    ">>=",
]

# Foreground colours that Tokenizer.decode(colorize=True) cycles through,
# one per token, indexed by token position modulo the list length.
colors = [
    colorama.Fore.RED,
    colorama.Fore.GREEN,
    colorama.Fore.YELLOW,
    colorama.Fore.BLUE,
    colorama.Fore.MAGENTA,
    colorama.Fore.CYAN,
    colorama.Fore.WHITE,
]

In [8]:
def build_special_symbols() -> set[str]:
    """Assemble the set of strings to reserve as special tokens.

    The set contains the ``EOF`` marker, every Python keyword and common
    built-in name both bare and prefixed with one space, and every operator
    prefixed with one space (operators are not added bare).
    """
    plain_words = python_keywords + python_common_builtins
    space_prefixed = {" " + item for item in plain_words + python_operations}
    return {EOF, *plain_words, *space_prefixed}

In [9]:
class Tokenizer:
    """Byte-pair-encoding tokenizer with regex pre-splitting and special tokens.

    Text is first cut into chunks by ``self.pattern``. Merges are learned and
    applied inside each chunk only, never across chunk boundaries. A chunk
    that is exactly a special token is mapped straight to that token's id.
    """

    def __init__(self, special_tokens: set[str] = set()):
        """Create an untrained tokenizer.

        Args:
            special_tokens: Strings that each get a dedicated id, starting at
                256. The default set is only read, never mutated.
        """
        # Next unused token id; ids 0..255 are the raw byte values.
        self.symbol = 256
        # Special-token string -> token id.
        self.special_tokens = dict()
        # Token id -> the bytes it stands for.
        self.vocab = {b: bytes([b]) for b in range(self.symbol)}
        # Iterate in sorted order: the iteration order of a set of strings can
        # differ between interpreter runs, which would make ids irreproducible.
        for special_token in sorted(special_tokens):
            self.vocab[self.symbol] = bytes(special_token, "utf-8")
            self.special_tokens[special_token] = self.symbol
            self.symbol += 1
        # Learned merges as (symbol, (left_id, right_id)), stored in the order
        # in which encode() has to apply them.
        self.merges = []
        self.pattern = self.compile_pattern()

    def compile_pattern(self):
        """Build the chunking regex: special tokens first, then the split pattern.

        Alternation picks the first alternative that matches, so special
        tokens are listed longest first. Otherwise a short token such as
        ``in`` could match ahead of a longer one such as ``int``.
        """
        if not self.special_tokens:
            return re.compile(GPT4_SPLIT_PATTERN)
        longest_first = sorted(self.special_tokens, key=len, reverse=True)
        alternatives = [re.escape(token) for token in longest_first]
        alternatives.append(GPT4_SPLIT_PATTERN)
        return re.compile("|".join(alternatives))

    def train(self, text: str, vocab_size: int, progress: bool = False) -> None:
        """Learn merges from ``text`` until the vocabulary has ``vocab_size`` entries.

        Training stops early when no chunk has an adjacent pair left. Calling
        ``train`` again continues from the merges learned so far.

        Args:
            text: Training corpus.
            vocab_size: Target number of vocabulary entries, counting byte
                tokens and special tokens.
            progress: When true, print ``current/target`` before each merge.
        """
        # Vocabulary values are bytes while special tokens are keyed by str,
        # so look special tokens up through their UTF-8 encoding.
        special_ids_by_bytes = {
            token.encode("utf-8"): symbol
            for token, symbol in self.special_tokens.items()
        }

        chunks = self.split_into_chunks(text)
        # Replay existing merges so a repeated call resumes training instead
        # of discarding them; this is a no-op on a fresh instance.
        merges = list(self.merges)
        for symbol, known_pair in merges:
            merge_with_cache(chunks, known_pair, symbol)

        while len(self.vocab) < vocab_size:
            if progress:
                print(f"{len(self.vocab)}/{vocab_size}")

            pair = most_common(
                chain.from_iterable(
                    zip(chunk, chunk[1:]) for chunk in chunks if len(chunk) > 1
                )
            )
            if pair is None:
                break

            new_token = self.vocab[pair[0]] + self.vocab[pair[1]]
            symbol = special_ids_by_bytes.get(new_token)
            if symbol is None:
                # Ordinary merge: allocate a fresh id.
                symbol = self.symbol
                self.vocab[symbol] = new_token
                self.symbol += 1
            # Otherwise the merged bytes spell an existing special token, so
            # its id is reused and the vocabulary does not grow.

            merges.append((symbol, pair))
            merge_with_cache(chunks, pair, symbol)

        # Keep learning order rather than sorting by id: a merge that targets
        # a low special-token id must still run after the merges that build
        # its two halves.
        self.merges = merges

    def split_into_chunks(self, text: str) -> list[list[int]]:
        """Cut ``text`` into chunks and convert each chunk to token ids.

        A chunk equal to a special token becomes a one-element list holding
        that token's id; any other chunk becomes its UTF-8 byte values.
        """
        chunks = []
        for chunk in self.pattern.findall(text):
            if not chunk:
                continue
            if chunk in self.special_tokens:
                chunks.append([self.special_tokens[chunk]])
            else:
                chunks.append(list(chunk.encode("utf-8")))
        return chunks

    def encode(self, text: str) -> list[int]:
        """Convert ``text`` to token ids by replaying the learned merges per chunk."""
        chunks = self.split_into_chunks(text)
        for symbol, pair_from_train in self.merges:
            merge_with_cache(chunks, pair_from_train, symbol)
        return list(chain.from_iterable(chunks))

    def decode(self, ids: list[int], colorize: bool = False, sep: str = "") -> str:
        """Convert token ids back to text.

        Args:
            ids: Token ids to decode.
            colorize: When true, prefix each token with a colour from
                ``colors``, cycling by position.
            sep: String placed between tokens.

        Returns:
            The decoded text. With no separator and no colouring, the bytes of
            all tokens are concatenated before decoding, so characters split
            across tokens are restored. Otherwise each token is decoded on its
            own. Bytes that are not valid UTF-8 are rendered with the Unicode
            replacement character instead of raising.

        Raises:
            KeyError: If an id is not in the vocabulary.
        """
        if not colorize:
            if not sep:
                raw = b"".join(self.vocab[token_id] for token_id in ids)
                return raw.decode("utf-8", errors="replace")
            return sep.join(
                self.vocab[token_id].decode("utf-8", errors="replace")
                for token_id in ids
            )

        colorama.init(autoreset=True)
        decoded = []
        for index, token_id in enumerate(ids):
            color = colors[index % len(colors)]
            token = self.vocab[token_id].decode("utf-8", errors="replace")
            decoded.append(color + token)
        return sep.join(decoded)


In [None]:
# Load the instruction dataset with `datasets.load_dataset` and keep its
# "train" split; the training cell reads each example's "output" field.
dataset = (
    load_dataset("iamtarun/python_code_instructions_18k_alpaca")["train"]
)

Generating train split:   0%|          | 0/18612 [00:00<?, ? examples/s]

Здесь токенизатор обучался дольше, чем на моём ноутбуке, поэтому файл с уже обученным токенизатором хранится отдельно и загружается ниже.

In [None]:

# Target vocabulary size passed to Tokenizer.train.
n = 30000
# Output file for the pickled tokenizer.
path = "data.pkl"

# Build one training corpus: the "output" field of every example, joined by
# the EOF marker (which is itself one of the special tokens).
text = EOF.join(example["output"] for example in dataset)
tokenizer = Tokenizer(build_special_symbols())

print("Training started")
# progress=True prints one "current/target" line per merge.
tokenizer.train(text, n, progress=True)

print(f"Saving to {path}")
# The whole Tokenizer instance is pickled, not just its vocab and merges.
with open(path, "wb") as file:
    pickle.dump(tokenizer, file)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
9551/30000
9552/30000
9553/30000
9554/30000
9555/30000
9556/30000
9557/30000
9558/30000
9559/30000
9560/30000
9561/30000
9562/30000
9563/30000
9564/30000
9565/30000
9566/30000
9567/30000
9568/30000
9569/30000
9570/30000
9571/30000
9572/30000
9573/30000
9574/30000
9575/30000
9576/30000
9577/30000
9578/30000
9579/30000
9580/30000
9581/30000
9582/30000
9583/30000
9584/30000
9585/30000
9586/30000
9587/30000
9588/30000
9589/30000
9590/30000
9591/30000
9592/30000
9593/30000
9594/30000
9595/30000
9596/30000
9597/30000
9598/30000
9599/30000
9600/30000
9601/30000
9602/30000
9603/30000
9604/30000
9605/30000
9606/30000
9607/30000
9608/30000
9609/30000
9610/30000
9611/30000
9612/30000
9613/30000
9614/30000
9615/30000
9616/30000
9617/30000
9618/30000
9619/30000
9620/30000
9621/30000
9622/30000
9623/30000
9624/30000
9625/30000
9626/30000
9627/30000
9628/30000
9629/30000
9630/30000
9631/30000
9632/30000
9633/30000
9634/30000
9635/30000


In [20]:
# Sample input for the encode/decode demo further down: the source code of the
# Tokenizer class, kept as a plain string.
# Note that this rebinds `text`, which the training cell used for the corpus.
text = """
class Tokenizer:
    def __init__(self, special_tokens: set[str] = set()):
        self.symbol = 256
        self.special_tokens = dict()
        self.vocab = {b: bytes([b]) for b in range(self.symbol)}
        for special_token in special_tokens:
            self.vocab[self.symbol] = bytes(special_token, "utf-8")
            self.special_tokens[special_token] = self.symbol
            self.symbol += 1
        self.merges = []
        self.pattern = self.compile_pattern()

    def compile_pattern(self):
        if not self.special_tokens:
            return re.compile(GPT4_SPLIT_PATTERN)
        else:
            escaped_tokens = [re.escape(t) for t in self.special_tokens]
            return re.compile("|".join(escaped_tokens) + r"|" + GPT4_SPLIT_PATTERN)

    def train(self, text: str, vocab_size: int, progress: bool = False) -> None:
        merges = dict()
        chunks = self.split_into_chunks(text)

        while len(self.vocab) < vocab_size:
            if progress:
                print(f"{len(self.vocab)}/{vocab_size}")

            pair = most_common(
                chain.from_iterable(
                    zip(chunk, chunk[1:]) for chunk in chunks if len(chunk) > 1
                )
            )
            if not pair:
                break

            new_token = self.vocab[pair[0]] + self.vocab[pair[1]]
            if new_token in self.special_tokens:
                symbol = self.special_tokens[new_token]
                merges[pair] = symbol
                merge_with_cache(chunks, pair, symbol)
                continue

            self.vocab[self.symbol] = new_token
            merges[pair] = self.symbol
            merge_with_cache(chunks, pair, self.symbol)
            self.symbol += 1

        self.merges = sorted((symbol, pair) for pair, symbol in merges.items())

    def split_into_chunks(self, text: str) -> list[list[int]]:
        chunks = []
        for chunk in re.findall(self.pattern, text):
            if not chunk:
                continue
            if chunk in self.special_tokens:
                chunks.append([self.special_tokens[chunk]])
            else:
                chunks.append(list(chunk.encode("utf-8")))
        return chunks

    def encode(self, text: str) -> list[int]:
        chunks = self.split_into_chunks(text)
        for symbol, pair_from_train in self.merges:
            merge_with_cache(chunks, pair_from_train, symbol)
        return list(chain.from_iterable(chunks))

    def decode(self, ids: list[int], colorize: bool = False, sep: str = "") -> str:
        if not colorize:
            return sep.join(self.vocab[id].decode("utf-8") for id in ids)
        else:
            colorama.init(autoreset=True)
            decoded = []
            for index, id in enumerate(ids):
                color = colors[index % len(colors)]
                token = self.vocab[id].decode("utf-8")
                decoded.append(color + token)
            return sep.join(decoded)

"""

In [10]:
# Remove any previous checkout, then clone the repository again so the cell
# can be re-run without `git clone` failing on an existing directory.
!rm -rf hw_ml4se
!git clone https://github.com/ReshetnikovPavel/hw_ml4se.git

Cloning into 'hw_ml4se'...
remote: Enumerating objects: 55, done.[K
remote: Counting objects:   1% (1/55)[Kremote: Counting objects:   3% (2/55)[Kremote: Counting objects:   5% (3/55)[Kremote: Counting objects:   7% (4/55)[Kremote: Counting objects:   9% (5/55)[Kremote: Counting objects:  10% (6/55)[Kremote: Counting objects:  12% (7/55)[Kremote: Counting objects:  14% (8/55)[Kremote: Counting objects:  16% (9/55)[Kremote: Counting objects:  18% (10/55)[Kremote: Counting objects:  20% (11/55)[Kremote: Counting objects:  21% (12/55)[Kremote: Counting objects:  23% (13/55)[Kremote: Counting objects:  25% (14/55)[Kremote: Counting objects:  27% (15/55)[Kremote: Counting objects:  29% (16/55)[Kremote: Counting objects:  30% (17/55)[Kremote: Counting objects:  32% (18/55)[Kremote: Counting objects:  34% (19/55)[Kremote: Counting objects:  36% (20/55)[Kremote: Counting objects:  38% (21/55)[Kremote: Counting objects:  40% (22/55)[Kremote: Counting

In [11]:
# Copy tokenizer.py from the checkout into the working directory.
# NOTE(review): presumably so that unpickling below can import the `tokenizer`
# module the saved object refers to — the notebook never imports it explicitly; confirm.
!cp 'hw_ml4se/hw3/tokenizer.py' 'tokenizer.py'

In [18]:
# Load the pre-trained tokenizer from the cloned repository instead of
# retraining it here.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load this file if the repository is trusted.
path = "hw_ml4se/hw3/tokenizer_python_code_instructions_18k_alpaca_30k_tokens.pkl"
with open(path, 'rb') as file:
    tokenizer: Tokenizer = pickle.load(file)


In [21]:
# Tokenize the sample source and print it back with "~" between tokens and a
# cycling colour per token, so the token boundaries are visible.
tokens = tokenizer.encode(text)
print(tokenizer.decode(tokens, colorize=True, sep="~"))


~class~ Tokenizer~:
~   ~ def~ __~in~it~__(~self~,~ special~_tokens~:~ set~[str~]~ =~ set~()):
~       ~ self~.symbol~ =~ ~256~
~       ~ self~.special~_tokens~ =~ dict~()
~       ~ self~.v~ocab~ =~ {~b~:~ bytes~([~b~])~ for~ b~ in~ range~(self~.symbol~)}
~       ~ for~ special~_token~ in~ special~_tokens~:
~           ~ self~.v~ocab~[self~.symbol~]~ =~ bytes~(special~_token~,~ "~utf~-~8~")
~           ~ self~.special~_tokens~[s~pecial~_token~]~ =~ self~.symbol~
~           ~ self~.symbol~ +~=~ ~1~
~       ~ self~.merge~s~ =~ []
~       ~ self~.p~attern~ =~ self~.compile~_pattern~()

~   ~ def~ compile~_pattern~(self~):
~       ~ if~ not~ self~.special~_tokens~:
~           ~ return~ re~.compile~(G~PT~4~_S~PL~IT~_PATTERN~)
~       ~ else~:
~           ~ esc~aped~_tokens~ =~ [~re~.escape~(t~)~ for~ t~ in~ self~.special~_tokens~]
~           ~ return~ re~.compile~("|~".~join~(~esc~aped~_tokens~)~ +~ r~"~|~"~ +~ GPT~4~_S~PL~IT~_PATTERN~)

~   ~ def~ train~(self~,~ text~:~ str~,~ vocab~_s