Skip to content

Commit

Permalink
Hint Generator type
Browse files (browse the repository at this point in the history)
  • Loading branch information
LoganDark committed Jun 5, 2023
1 parent 8ccb10a commit c01f3fb
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tokenizer/rwkv_tokenizer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -217,14 +217,15 @@ def printTokens(self, tokens):
# Tokenizer #4 (fast) https://github.com/LoganDark
########################################################################################################

from typing import Generator
from ast import literal_eval

class FastTokenizer:
__slots__ = ('tok2val', 'tok2len', 'root')

def __init__(self, file_name):
self.tok2val = [b''] * 65536
self.tok2len = [0] * 65536
self.tok2val = {}
self.tok2len = {}
self.root = {}

with open(file_name, 'rt', encoding = 'utf-8') as file:
Expand Down Expand Up @@ -255,7 +256,7 @@ def next_token(self, src: bytes) -> int:
break
return last_token

def encode_bytes(self, src: bytes) -> list[int]:
def encode_bytes(self, src: bytes) -> Generator[int, None, None]:
start, stop = 0, len(src)
while start < stop:
last_token, last = None, self.root
Expand All @@ -272,9 +273,9 @@ def encode_bytes(self, src: bytes) -> list[int]:
else: break

def decode_bytes(self, tokens: list[int]) -> bytes:
return b''.join(map(self.tok2val.__getitem__, tokens))
return b''.join(map(self.tok2val.get, tokens))

def encode(self, src: str) -> list[int]:
def encode(self, src: str) -> Generator[int, None, None]:
return self.encode_bytes(src.encode('utf-8'))

def decode(self, tokens: list[int]) -> str:
Expand Down

0 comments on commit c01f3fb

Please sign in to comment.