diff --git a/tokenizer/rwkv_tokenizer.py b/tokenizer/rwkv_tokenizer.py index a445950b..3ea0b489 100644 --- a/tokenizer/rwkv_tokenizer.py +++ b/tokenizer/rwkv_tokenizer.py @@ -217,14 +217,15 @@ def printTokens(self, tokens): # Tokenizer #4 (fast) https://github.com/LoganDark ######################################################################################################## +from typing import Generator from ast import literal_eval class FastTokenizer: __slots__ = ('tok2val', 'tok2len', 'root') def __init__(self, file_name): - self.tok2val = [b''] * 65536 - self.tok2len = [0] * 65536 + self.tok2val = {} + self.tok2len = {} self.root = {} with open(file_name, 'rt', encoding = 'utf-8') as file: @@ -255,7 +256,7 @@ def next_token(self, src: bytes) -> int: break return last_token - def encode_bytes(self, src: bytes) -> list[int]: + def encode_bytes(self, src: bytes) -> Generator[int, None, None]: start, stop = 0, len(src) while start < stop: last_token, last = None, self.root @@ -272,9 +273,9 @@ def encode_bytes(self, src: bytes) -> list[int]: else: break def decode_bytes(self, tokens: list[int]) -> bytes: - return b''.join(map(self.tok2val.__getitem__, tokens)) + return b''.join(self.tok2val.get(t, b'') for t in tokens) - def encode(self, src: str) -> list[int]: + def encode(self, src: str) -> Generator[int, None, None]: return self.encode_bytes(src.encode('utf-8')) def decode(self, tokens: list[int]) -> str: