In [13]:
class Word_Encoder(nn.Module):  # Tokenizer + embedder for letters
    """Early draft of the character embedder.

    NOTE: a later cell in this notebook redefines ``Word_Encoder``; this
    version only builds the vocabulary and the embedding tables.

    Args:
        alphabet: Iterable of characters forming the base vocabulary.
        emb_size: Dimension of the token and positional embeddings.
        max_word_size: Number of positional embeddings.
    """

    def __init__(self, alphabet, emb_size, max_word_size=256):
        super().__init__()
        # Letters + special tokens: padding, stress mark and unknown character.
        self.alphabet = list(alphabet) + ["<pad>", "<stress>", "<unk>"]
        self.emb_size = emb_size
        self.embeddings = nn.Embedding(len(self.alphabet), emb_size)
        self.pos_embeddings = nn.Embedding(max_word_size, emb_size)
        # nn.Module has no `.device` attribute; take it from the weight tensor.
        self.device = self.embeddings.weight.device

class Word_Encoder(nn.Module):
    """Character-level tokenizer and embedder for words.

    The vocabulary is the characters of ``alphabet`` followed by three special
    tokens: ``<pad>`` (padding), ``<stress>`` (stress mark, written inline as
    the literal text ``<stress>``) and ``<unk>`` (any other character).

    Args:
        alphabet: Iterable of single characters forming the base vocabulary.
        emb_size: Dimension of both the token and the positional embeddings.
        max_word_size: Number of positional embeddings, i.e. the longest
            token sequence ``forward`` accepts.
    """

    _STRESS = "<stress>"

    def __init__(self, alphabet, emb_size, max_word_size=256):
        super().__init__()
        self.alphabet = list(alphabet) + ["<pad>", "<stress>", "<unk>"]
        self.emb_size = emb_size
        self.max_word_size = max_word_size
        self.embeddings = nn.Embedding(len(self.alphabet), emb_size)
        self.pos_embeddings = nn.Embedding(max_word_size, emb_size)

        # Symbol -> vocabulary index (a repeated character keeps its last index).
        self.get_idx = {char: idx for idx, char in enumerate(self.alphabet)}
        self.pad_idx = self.get_idx["<pad>"]
        self.stress_idx = self.get_idx["<stress>"]
        self.unk_idx = self.get_idx["<unk>"]
        # Kept for callers that read `encoder.device`; `tokenize` uses the live
        # weight device instead so it follows `.to(device)`.
        self.device = self.embeddings.weight.device

    def _tokenize_word(self, word):
        """Return the list of vocabulary indices for a single word."""
        idxs = []
        i = 0
        while i < len(word):
            # startswith(tag, i) also matches a tag that ends the word.
            if word.startswith(self._STRESS, i):
                idxs.append(self.stress_idx)
                i += len(self._STRESS)
            else:
                idxs.append(self.get_idx.get(word[i], self.unk_idx))
                i += 1
        return idxs

    def tokenize(self, text):
        """Convert a word or a list of words into a right-padded index tensor.

        Args:
            text: A single word (``str``) or a list of words.

        Returns:
            ``torch.long`` tensor of shape ``(number_of_words, max_len)``,
            right-padded with ``pad_idx``, on the embedding weights' device.
        """
        if isinstance(text, str):
            text = [text]
        tokenized = [self._tokenize_word(word) for word in text]
        max_len = max((len(idxs) for idxs in tokenized), default=0)
        # Build new lists rather than extending the per-word lists in place.
        padded = [
            idxs + [self.pad_idx] * (max_len - len(idxs)) for idxs in tokenized
        ]
        # Explicit dtype: an all-empty batch would otherwise become float32.
        # reshape keeps the result 2-D even for an empty list of words.
        return torch.tensor(
            padded, dtype=torch.long, device=self.embeddings.weight.device
        ).reshape(len(padded), max_len)

    def forward(self, x):
        """Embed token indices and add learned positional embeddings.

        Works on batches: the last dimension of ``x`` is the sequence, and the
        positional term is broadcast over all leading dimensions.

        Args:
            x: Integer index tensor, e.g. the output of ``tokenize``.

        Returns:
            Tensor of shape ``x.shape + (emb_size,)``.

        Raises:
            ValueError: if the sequence is longer than ``max_word_size``.
        """
        self.device = x.device
        n = x.shape[-1]
        if n > self.max_word_size:
            raise ValueError(
                f"sequence length {n} exceeds max_word_size={self.max_word_size}"
            )
        # Position ids 0..n-1 (not the token ids) index the positional table.
        pos = torch.arange(n, device=x.device)
        return self.embeddings(x) + self.pos_embeddings(pos)


In [16]:
# Russian lowercase alphabet (33 letters, "ё" included) with 128-dim embeddings.
encoder = Word_Encoder(alphabet="абвгдеёжзийклмнопрстуфхцчшщъыьэюя", emb_size=128)
print(encoder.get_idx)

# A bare string is tokenized as a batch containing one word.
tokens1 = encoder.tokenize(text="привет")
print(tokens1)

# Two words: one with an inline <stress> tag, one with an out-of-alphabet "2"
# (mapped to <unk>); the shorter row is right-padded with <pad>.
tokens2 = encoder.tokenize(text=["прив<stress>ет", "м2ир"])
print(tokens2)

# A longer single word.
tokens3 = encoder.tokenize(text="уратыработаешь")
print(tokens3)

{'а': 0, 'б': 1, 'в': 2, 'г': 3, 'д': 4, 'е': 5, 'ё': 6, 'ж': 7, 'з': 8, 'и': 9, 'й': 10, 'к': 11, 'л': 12, 'м': 13, 'н': 14, 'о': 15, 'п': 16, 'р': 17, 'с': 18, 'т': 19, 'у': 20, 'ф': 21, 'х': 22, 'ц': 23, 'ч': 24, 'ш': 25, 'щ': 26, 'ъ': 27, 'ы': 28, 'ь': 29, 'э': 30, 'ю': 31, 'я': 32, '<pad>': 33, '<stress>': 34, '<unk>': 35}
tensor([[16, 17,  9,  2,  5, 19]])
tensor([[16, 17,  9,  2, 34,  5, 19],
        [13, 35,  9, 17, 33, 33, 33]])
tensor([[20, 17,  0, 19, 28, 17,  0,  1, 15, 19,  0,  5, 25, 29]])
