From c01f3fbcc204f39ce06abee77551b59e83e1d92e Mon Sep 17 00:00:00 2001
From: LoganDark
Date: Sun, 4 Jun 2023 16:53:31 -0700
Subject: [PATCH] Hint Generator type

---
 tokenizer/rwkv_tokenizer.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/tokenizer/rwkv_tokenizer.py b/tokenizer/rwkv_tokenizer.py
index a445950b..3ea0b489 100644
--- a/tokenizer/rwkv_tokenizer.py
+++ b/tokenizer/rwkv_tokenizer.py
@@ -217,14 +217,15 @@ def printTokens(self, tokens):
 # Tokenizer #4 (fast) https://github.com/LoganDark
 ########################################################################################################
 
+from typing import Generator
 from ast import literal_eval
 
 class FastTokenizer:
 	__slots__ = ('tok2val', 'tok2len', 'root')
 
 	def __init__(self, file_name):
-		self.tok2val = [b''] * 65536
-		self.tok2len = [0] * 65536
+		self.tok2val = {}
+		self.tok2len = {}
 		self.root = {}
 
 		with open(file_name, 'rt', encoding = 'utf-8') as file:
@@ -255,7 +256,7 @@ def next_token(self, src: bytes) -> int:
 				break
 		return last_token
 
-	def encode_bytes(self, src: bytes) -> list[int]:
+	def encode_bytes(self, src: bytes) -> Generator[int, None, None]:
 		start, stop = 0, len(src)
 		while start < stop:
 			last_token, last = None, self.root
@@ -272,9 +273,9 @@ def encode_bytes(self, src: bytes) -> list[int]:
 			else: break
 
 	def decode_bytes(self, tokens: list[int]) -> bytes:
-		return b''.join(map(self.tok2val.__getitem__, tokens))
+		return b''.join(map(self.tok2val.get, tokens))
 
-	def encode(self, src: str) -> list[int]:
+	def encode(self, src: str) -> Generator[int, None, None]:
 		return self.encode_bytes(src.encode('utf-8'))
 
 	def decode(self, tokens: list[int]) -> str: