In [1]:
import sentencepiece as spm

In [2]:
# Trained SentencePiece model for the tiny-Shakespeare corpus. Kept as a named
# constant so the model can be swapped without touching the loading code.
# NOTE: relative path, resolved against the kernel's working directory.
SENTENCEPIECE_MODEL_FILE = "model/sentencepiece/tiny_shakespeare_67.model"

processor = spm.SentencePieceProcessor(model_file=SENTENCEPIECE_MODEL_FILE)


In [19]:
import sentencepiece as spm  # type: ignore


class Encoder:
    """Interface for encoders that map text to token ids and back.

    Subclasses are expected to override these methods; the base versions
    only raise ``NotImplementedError``.
    """

    def encode(self, input: str) -> list[int]:
        """Convert ``input`` text into a list of token ids."""
        raise NotImplementedError

    def decode(self, input: list[int]) -> str:
        """Convert a list of token ids back into text."""
        raise NotImplementedError

    def vocab_size(self) -> int:
        """Return the number of distinct token ids this encoder can emit."""
        raise NotImplementedError


class CharEncoder(Encoder):
    """Character-level encoder whose vocabulary is the set of characters given.

    Ids are assigned in sorted character order, so the mapping depends only on
    which characters occur in ``chars``, not on their order or frequency.
    """

    _char_to_index: dict[str, int]  # character -> id
    _index_to_char: list[str]  # id -> character (sorted, unique)

    def __init__(self, chars: str) -> None:
        super().__init__()
        alphabet = sorted({*chars})
        self._index_to_char = alphabet
        self._char_to_index = dict(zip(alphabet, range(len(alphabet))))

    def encode(self, input: str) -> list[int]:
        """Map each character of ``input`` to its id.

        Raises ``KeyError`` for a character that is not in the vocabulary.
        """
        lookup = self._char_to_index
        return list(map(lookup.__getitem__, input))

    def decode(self, input: list[int]) -> str:
        """Map each id back to its character and concatenate the result."""
        return "".join(map(self._index_to_char.__getitem__, input))

    def vocab_size(self) -> int:
        """Number of distinct characters in the vocabulary."""
        return len(self._char_to_index)


class SentencePieceEncoder(Encoder):
    """Encoder backed by a SentencePiece model, using EOS as the newline marker.

    Each line of the input is tokenized independently by SentencePiece, and
    the model's EOS id is inserted between consecutive lines. As a result, n
    newline characters produce exactly n EOS ids. ``decode`` reverses this by
    splitting the ids on the EOS id, decoding each piece, and joining the
    pieces with newlines.
    """

    _processor: spm.SentencePieceProcessor
    _eos_id: int

    def __init__(self, processor: spm.SentencePieceProcessor) -> None:
        """Wrap a loaded SentencePiece ``processor``.

        Raises:
            ValueError: if the model has no EOS piece (``eos_id()`` is
                negative). EOS is required to represent newlines.
        """
        super().__init__()
        self._processor = processor
        self._eos_id = processor.eos_id()
        # A negative id means EOS is disabled in the model. Using it as the
        # newline marker would silently emit invalid ids, so fail fast instead.
        if self._eos_id < 0:
            raise ValueError(
                f"SentencePiece model has no EOS piece (eos_id={self._eos_id}); "
                "it is required to encode newlines"
            )

    def encode(self, input: str) -> list[int]:
        """Tokenize ``input`` line by line, separating lines with the EOS id."""
        result: list[int] = []
        # Batch-encode all lines in one call; the processor returns one id
        # list per input line.
        encoded_lines = self._processor.encode(input.split('\n'))
        for i, encoded_line in enumerate(encoded_lines):
            if i > 0:
                result.append(self._eos_id)
            result.extend(encoded_line)
        return result

    def decode(self, input: list[int]) -> str:
        """Inverse of ``encode``: split on the EOS id, decode, and join with newlines."""
        return '\n'.join(self._processor.decode(_split_list(input, self._eos_id)))

    def vocab_size(self) -> int:
        """Number of pieces in the underlying SentencePiece model."""
        return self._processor.vocab_size()

def _split_list(input: list[int], delimiter: int) -> list[list[int]]:
    offsets = [i for i, val in enumerate(input) if val == delimiter]
    last_offset = 0
    result = []
    for offset in offsets:
        result.append(input[last_offset:offset])
        last_offset = offset
    result.append(input[last_offset:])
    return result

In [7]:
# Scratch check: indices of the newline characters in a sample string.
[pos for pos, char in enumerate('hello\nworld!\n') if char == '\n']

[5, 12]

In [8]:
# Scratch check: a trailing newline leaves an empty final piece, so
# SentencePieceEncoder.encode emits one EOS id per newline character.
'hello\nworld!\n'.split(sep='\n')

['hello', 'world!', '']

In [20]:
# Wrap the loaded SentencePiece model; newlines are encoded as its EOS id.
encoder = SentencePieceEncoder(processor=processor)

In [22]:
# Encode a sample that ends with a blank line and check that the
# EOS-as-newline scheme round-trips losslessly. The ids stay the displayed output.
sample_text = 'hello!\nworld!\n\n'
sample_ids = encoder.encode(sample_text)
assert encoder.decode(sample_ids) == sample_text, "SentencePieceEncoder round-trip mismatch"
sample_ids

[3, 13, 4, 11, 11, 6, 45, 2, 3, 17, 6, 12, 11, 14, 45, 2, 2]

In [26]:
# Edge case: decoding an empty id list should give the empty string.
encoder.decode(input=[])

''