In [1]:
from tokenizer import train_mopiece, MOPiece
from word_autoencoder import WordEncoder, WordDecoder
import torch as pt
from torch import nn
import pandas as pl
import regex as re
from tqdm.notebook import trange, tqdm

# Pick the compute device once for the whole notebook: CUDA if available,
# otherwise Apple's MPS backend if available, otherwise the CPU.
device = ('cuda' if pt.cuda.is_available() else 'mps' if pt.backends.mps.is_available() else 'cpu')

In [2]:
# Affix inventories for the tokenizer: the 'suffix' / 'prefix' column of each CSV.
# NOTE: `pl` is the alias this notebook gives to pandas (see the import cell).
suffixes = pl.read_csv('data/morphynet/suffixes.csv')['suffix']
prefixes = pl.read_csv('data/morphynet/prefixes.csv')['prefix']

# Train the tokenizer on the corpus and store it under the '____tokenizer' name.
# NOTE(review): 3000 is presumably the target vocabulary size — confirm against
# the signature of `train_mopiece`.
train_mopiece('____tokenizer', ['data/text/bible.txt'], prefixes, suffixes, 3000, spm_model_type='bpe')

# Load the tokenizer that was just trained.
mopiece = MOPiece('____tokenizer')

In [3]:
# Collect the lower-cased vocabulary of the corpus.
words = set()
# Split on runs of non-letter/mark/digit, non-whitespace characters or on a single
# whitespace character. The capturing group makes `split` return the separators
# too, so punctuation runs are kept as vocabulary entries of their own.
# (`re` is the third-party `regex` module here, which is what supports `\p{...}`.)
reg = re.compile(r'([^\p{L}\p{M}\p{N}\s]+|\s)')
with open('data/text/bible.txt', mode='r', encoding='utf8') as file:
    for line in file:
        for word in reg.split(line):
            # Drop empty pieces and *every* whitespace separator. Comparing against
            # ' ' alone let the '\n' that ends each line (and tabs, etc.) slip into
            # the vocabulary as a "word".
            if not word or word.isspace():
                continue
            words.add(word.lower())

In [4]:
class WordDataset(pt.utils.data.Dataset):
    """In-memory dataset of words pre-tokenized into three id sequences.

    Each item is a ``(prefix, spm, suffix)`` tuple of 1-D ``long`` tensors. Every
    sequence is wrapped as ``[bos] + ids + [eos]`` and created directly on ``device``.

    Args:
        words: iterable of strings to encode.
        mopiece: tokenizer providing ``bos_id()``, ``eos_id()`` and
            ``encode_word(word)``; the latter must return three lists of ids.
        device: device on which the tensors are allocated (``None`` = default).
    """

    def __init__(self, words, mopiece, device=None):
        super().__init__()
        bos, eos = mopiece.bos_id(), mopiece.eos_id()

        def framed(ids):
            # Surround one id list with the begin/end markers.
            return pt.tensor([bos] + ids + [eos], dtype=pt.long, device=device)

        def encode(word):
            prefix_part, spm_part, suffix_part = mopiece.encode_word(word)
            return framed(prefix_part), framed(spm_part), framed(suffix_part)

        self.words = [encode(word) for word in words]

    def __len__(self):
        """Number of encoded words."""
        return len(self.words)

    def __getitem__(self, index):
        """Return the ``(prefix, spm, suffix)`` tensor triple stored at ``index``."""
        return self.words[index]
    
class SequenceCrossEntropyLoss(nn.CrossEntropyLoss):
    """Cross-entropy averaged first within each sequence, then over the batch.

    For every batch element the summed cross-entropy of the prefix, spm and suffix
    sequences is divided by the combined length of the three target sequences; the
    per-element values are then averaged over the batch dimension.
    """

    def __init__(self, *args, **kwargs):
        # The parent must return sums so that the normalisation can be done here.
        super().__init__(*args, reduction='sum', **kwargs)

    def forward(self, prefix_logits, spm_logits, suffix_logits, prefix_ids, spm_ids, suffix_ids):
        summed_xent = super().forward
        total = 0.
        per_sequence = zip(prefix_logits, spm_logits, suffix_logits, prefix_ids, spm_ids, suffix_ids)
        for p_logits, r_logits, s_logits, p_ids, r_ids, s_ids in per_sequence:
            sequence_sum = summed_xent(p_logits, p_ids) + summed_xent(r_logits, r_ids) + summed_xent(s_logits, s_ids)
            positions = len(p_ids) + len(r_ids) + len(s_ids)
            total += sequence_sum / positions
        return total / prefix_logits.shape[0]
    
# Encode the whole vocabulary once, up front; the tensors are created on `device`.
dset = WordDataset(words, mopiece, device=device)

def collate_fn(batch):
    """Pad a list of ``(prefix, spm, suffix)`` tensor triples into three batch tensors.

    Each component is right-padded with the tokenizer's pad id to the longest
    sequence of that component in the batch.

    Returns:
        A ``(prefix_batch, spm_batch, suffix_batch)`` tuple of batch-first tensors.
    """
    pad_value = mopiece.pad_id()

    prefix_seqs, spm_seqs, suffix_seqs = [], [], []
    for prefix, spm, suffix in batch:
        prefix_seqs.append(prefix)
        spm_seqs.append(spm)
        suffix_seqs.append(suffix)

    def pad(sequences):
        return pt.nn.utils.rnn.pad_sequence(
            sequences,
            batch_first=True,
            padding_value=pad_value,
            padding_side='right',
        )

    return (pad(prefix_seqs), pad(spm_seqs), pad(suffix_seqs))

In [None]:
# Build the word autoencoder and move both halves to `device`.
# NOTE(review): the positional 256 is presumably the word-embedding width shared
# by encoder and decoder — confirm against `word_autoencoder`.
encoder = WordEncoder(mopiece.vocab_size(), mopiece.pad_id(), 256, ffn_hidden_dim=512, expansion_factor=4, spm_layers=6).to(device)
decoder = WordDecoder(mopiece.vocab_size(), mopiece.pad_id(), mopiece.bos_id(), mopiece.eos_id(), 256, spm_dim=512, suffix_dim=256, prefix_dim=256, num_layers=8, expansion_factor=2).to(device)

In [25]:
# criterion = SequenceCrossEntropyLoss(ignore_index=mopiece.pad_id())
# Token-level cross-entropy with label smoothing; padded target positions are ignored.
xent = nn.CrossEntropyLoss(ignore_index=mopiece.pad_id(), label_smoothing=.1)
# Collapse all leading dimensions so the loss sees (tokens, vocab) logits and (tokens,) labels.
criterion = lambda logits, labels: xent(logits.flatten(end_dim=-2), labels.flatten())

epochs = 50
batch_size = 128

loader = pt.utils.data.DataLoader(dset, batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn)

optim = pt.optim.AdamW([
    {"params": encoder.parameters()},
    {"params": decoder.parameters()}
], weight_decay=0.01, lr=3e-4)
# T_max is expressed in epochs, so the scheduler is stepped once per epoch below.
lr_scheduler = pt.optim.lr_scheduler.CosineAnnealingLR(optim, epochs)

encoder.train()
decoder.train()

log_perplexities = []
for epoch in (pbar := trange(epochs, desc='Epoch')):
    log_perplexity_sum = 0
    for prefix_ids, spm_ids, suffix_ids in tqdm(loader, desc='Training', leave=False):
        optim.zero_grad()

        embedding = encoder(prefix_ids, spm_ids, suffix_ids)

        # Teacher forcing: the decoder sees every token but the last and is scored
        # against every token but the first.
        prefix_logits, spm_logits, suffix_logits = decoder(prefix_ids[..., :-1], spm_ids[..., :-1], suffix_ids[..., :-1], embedding)

        loss = criterion(prefix_logits, prefix_ids[..., 1:]) + criterion(spm_logits, spm_ids[..., 1:]) + criterion(suffix_logits, suffix_ids[..., 1:])

        loss.backward()
        optim.step()

        log_perplexity_sum += loss.item()

    # Advance the cosine schedule. Without this call the scheduler was created but
    # never used, and the learning rate stayed at its initial value for the whole run.
    lr_scheduler.step()

    log_perplexity = log_perplexity_sum / len(loader)
    log_perplexities.append(log_perplexity)
    pbar.set_postfix_str(f'log-perplexity: {log_perplexity:.2f}')



Epoch:   0%|          | 0/50 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [6]:
# Debugging aid: run exactly one optimisation step and print the three loss
# components (prefix, spm, suffix) separately.
for prefix_ids, spm_ids, suffix_ids in tqdm(loader, desc='Training', leave=False):
    optim.zero_grad()

    embedding = encoder(prefix_ids, spm_ids, suffix_ids)

    decoder_inputs = (prefix_ids[..., :-1], spm_ids[..., :-1], suffix_ids[..., :-1])
    prefix_logits, spm_logits, suffix_logits = decoder(*decoder_inputs, embedding)

    scored_pairs = (
        (prefix_logits, prefix_ids),
        (spm_logits, spm_ids),
        (suffix_logits, suffix_ids),
    )
    loss_p, loss_r, loss_s = [criterion(logits, ids[..., 1:]) for logits, ids in scored_pairs]
    loss = loss_p + loss_r + loss_s

    print(loss_p.item(), loss_r.item(), loss_s.item())

    loss.backward()
    optim.step()

    # Only the first batch is needed.
    break

NameError: name 'loader' is not defined

In [37]:
# Persist only the parameters (state dicts) of both models, one file each.
pt.save(encoder.state_dict(), "____encoder.pt")
pt.save(decoder.state_dict(), "____decoder.pt")

In [38]:
# Restore the parameters into the already-constructed models.
# `weights_only=True` restricts unpickling to tensors and plain containers.
encoder.load_state_dict(pt.load("____encoder.pt", weights_only=True))
decoder.load_state_dict(pt.load("____decoder.pt", weights_only=True))

<All keys matched successfully>

In [6]:
from random import sample

# Qualitative check: encode one randomly chosen vocabulary word and decode it back
# with beam search, printing the ids and text on both sides.
with pt.no_grad():
    encoder.eval()
    decoder.eval()
    bos_id, eos_id = mopiece.bos_id(), mopiece.eos_id()

    (word,) = sample(list(words), 1)
    prefix_ids, spm_ids, suffix_ids = [[bos_id] + ids + [eos_id] for ids in mopiece.encode_word(word)]
    print('input:')
    print(prefix_ids, spm_ids, suffix_ids)
    print(word)

    def as_ids_tensor(ids):
        return pt.tensor(ids, dtype=pt.long, device=device)

    embedding = encoder(as_ids_tensor(prefix_ids), as_ids_tensor(spm_ids), as_ids_tensor(suffix_ids))

    out_prefix_ids, out_spm_ids, out_suffix_ids = decoder.beam_search(embedding)
    print('\noutput:')
    print(out_prefix_ids, out_spm_ids, out_suffix_ids)
    print(mopiece.decode_word(out_prefix_ids, out_spm_ids, out_suffix_ids))

input:
[0, 1] [0, 1455, 1] [0, 116, 419, 1]
woundedst

output:
[0, 164, 1] [0, 1048, 686, 633, 814, 419, 2701, 766, 1506, 2387, 2903, 2813, 320, 1] [0, 200, 466, 227, 137, 274, 356, 16, 476, 50, 207, 259, 280, 231, 240, 388, 207, 259, 280, 156, 265, 277, 239, 198, 478, 11, 296, 298, 304, 62, 350, 261, 425, 88, 386, 222, 261, 425, 88, 386, 375, 483, 441, 261, 425, 88, 386, 222, 261, 425, 88, 386, 222, 261, 425, 88, 386, 222, 261, 425, 88, 386, 222, 227, 291]
colievemountadieorrowideamahrivfacjoelleeppla16icitiwateriserslogpersonadelicworkaryierlandlyticismitisaucierlandlyticferousletlongiteicideworthiacmetricmobilnautativeouslatriteencharyinglatriteencharypocalypsezillatylatriteencharyinglatriteencharyinglatriteencharyinglatriteencharyingisment


In [106]:
import torch as pt
from torch import nn
import torch.nn.functional as F

from common import Encoder, Decoder, RoPEMHSA, MultiheadCrossAttention, FFN


In [109]:
class EnhancedWordEncoder(nn.Module):
    """Encode one word, given as prefix / SentencePiece / suffix id sequences, into a vector.

    Three branches are computed and then fused:

    * SentencePiece ids pass through ``spm_layers`` pre-norm residual blocks
      (``RoPEMHSA`` + ``FFN``) and are pooled by multi-head attention that uses one
      learned query per head.
    * Prefix ids and suffix ids each pass through a 3-layer ``nn.TransformerEncoder``
      and are mean-pooled over their non-padding positions.

    The three ``embedding_dim``-sized vectors are concatenated and reduced to a single
    ``embedding_dim`` vector by ``integration_net``.

    Args:
        vocab_size: indexable holding three vocabulary sizes; index 0 is used for the
            prefix embedding, index 1 for the SentencePiece embedding and index 2 for
            the suffix embedding.
        pad_id: padding id, shared by all three id spaces.
        embedding_dim: size of the produced word embedding.
        spm_dim, suffix_dim, prefix_dim: model widths of the three branches.
        spm_layers: number of blocks in the SentencePiece branch.
        spm_heads: attention heads in the SentencePiece branch and in its pooling.
        spm_rope_cache: forwarded to ``RoPEMHSA`` (presumably the RoPE cache length —
            confirm in ``common``).
        ffn_hidden_dim: hidden width handed to ``FFN``.
        expansion_factor: accepted for signature compatibility; not used by this class.
        dropout: dropout probability used throughout.
    """

    def __init__(self, vocab_size, pad_id, embedding_dim=256, spm_dim=512,
                 suffix_dim=256, prefix_dim=256, spm_layers=6, spm_heads=16,
                 spm_rope_cache=128, ffn_hidden_dim=1024, expansion_factor=4, dropout=0.2):
        super().__init__()

        self.vocab_size = vocab_size
        self.pad_id = pad_id
        self.dropout = nn.Dropout(dropout)

        # --- SentencePiece branch -------------------------------------------------
        self.spm_embedding = nn.Embedding(vocab_size[1], spm_dim, padding_idx=pad_id)
        nn.init.normal_(self.spm_embedding.weight, mean=0.0, std=0.02)

        self.spm_encoder_layers = nn.ModuleList([
            nn.ModuleDict({
                'attn': RoPEMHSA(spm_dim, spm_dim // spm_heads, spm_dim // spm_heads,
                                spm_heads, dropout, spm_rope_cache),
                'ffn': FFN(spm_dim, ffn_hidden_dim, dropout),
                'norm1': nn.RMSNorm(spm_dim),
                'norm2': nn.RMSNorm(spm_dim)
            })
            for _ in range(spm_layers)
        ])

        self.spm_final_norm = nn.RMSNorm(spm_dim)

        # Attention pooling: one learned query per head, keys/values from the tokens.
        self.spm_pooling_q = nn.Parameter(pt.randn(spm_heads, spm_dim // spm_heads) * 0.02)
        self.spm_pooling_k = nn.Linear(spm_dim, spm_dim)
        self.spm_pooling_v = nn.Linear(spm_dim, spm_dim)
        self.spm_pooling_proj = nn.Linear(spm_dim, embedding_dim)

        # --- Prefix branch --------------------------------------------------------
        self.prefix_embedding = nn.Embedding(vocab_size[0], prefix_dim, padding_idx=pad_id)
        nn.init.normal_(self.prefix_embedding.weight, mean=0.0, std=0.02)

        self.prefix_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=prefix_dim,
                nhead=8,
                dim_feedforward=prefix_dim * 2,
                dropout=dropout,
                activation='gelu',
                batch_first=True,
                norm_first=True
            ),
            num_layers=3
        )

        self.prefix_pooling = nn.Sequential(
            nn.Linear(prefix_dim, embedding_dim),
            nn.LayerNorm(embedding_dim),
            nn.GELU()
        )

        # --- Suffix branch --------------------------------------------------------
        self.suffix_embedding = nn.Embedding(vocab_size[2], suffix_dim, padding_idx=pad_id)
        nn.init.normal_(self.suffix_embedding.weight, mean=0.0, std=0.02)

        self.suffix_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=suffix_dim,
                nhead=8,
                dim_feedforward=suffix_dim * 2,
                dropout=dropout,
                activation='gelu',
                batch_first=True,
                norm_first=True
            ),
            num_layers=3
        )

        self.suffix_pooling = nn.Sequential(
            nn.Linear(suffix_dim, embedding_dim),
            nn.LayerNorm(embedding_dim),
            nn.GELU()
        )

        # --- Fusion of the three branch vectors -----------------------------------
        self.integration_net = nn.Sequential(
            nn.Linear(embedding_dim * 3, embedding_dim * 2),
            nn.LayerNorm(embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.LayerNorm(embedding_dim)
        )

    @staticmethod
    def _masked_mean(hidden, keep_mask):
        """Average ``hidden`` over its second-to-last axis where ``keep_mask`` is True.

        Rows whose mask is entirely False yield zeros instead of NaN, because the
        divisor is clamped to at least 1.
        """
        weights = keep_mask.unsqueeze(-1).float()
        summed = (hidden * weights).sum(dim=-2)
        count = weights.sum(dim=-2).clamp(min=1.0)
        return summed / count

    def _encode_affix(self, ids, embedding, encoder, pooling):
        """Shared prefix/suffix path: embed, encode, mean-pool real tokens, project."""
        keep_mask = ids != self.pad_id
        hidden = self.dropout(embedding(ids))
        # `src_key_padding_mask` marks the positions to ignore, hence the inversion.
        hidden = encoder(hidden, src_key_padding_mask=~keep_mask)
        return pooling(self._masked_mean(hidden, keep_mask))

    def _encode_spm(self, spm_ids):
        """SentencePiece path: pre-norm blocks followed by learned-query attention pooling."""
        keep_mask = spm_ids != self.pad_id
        # Pairwise mask: position i may attend to position j only if both are real tokens.
        attn_mask = keep_mask.unsqueeze(-1) & keep_mask.unsqueeze(-2)

        hidden = self.dropout(self.spm_embedding(spm_ids))
        for layer in self.spm_encoder_layers:
            hidden = hidden + layer['attn'](layer['norm1'](hidden), attn_mask)
            hidden = hidden + layer['ffn'](layer['norm2'](hidden))
        hidden = self.spm_final_norm(hidden)

        batch_shape = hidden.shape[:-2]
        seq_len = hidden.shape[-2]
        num_heads, head_dim = self.spm_pooling_q.shape

        # (..., L, H * d) -> (..., H, L, d)
        keys = self.spm_pooling_k(hidden).view(*batch_shape, seq_len, num_heads, head_dim).transpose(-2, -3)
        values = self.spm_pooling_v(hidden).view(*batch_shape, seq_len, num_heads, head_dim).transpose(-2, -3)
        # One query per head, broadcast over the batch dimensions: (1, ..., H, 1, d).
        query = self.spm_pooling_q.view(*(1,) * len(batch_shape), num_heads, 1, head_dim)

        pooled = F.scaled_dot_product_attention(
            query,
            keys,
            values,
            keep_mask.view(*batch_shape, 1, 1, seq_len)
        ).flatten(-3)  # (..., H, 1, d) -> (..., H * d)
        return self.spm_pooling_proj(pooled)

    def forward(self, prefix_ids, spm_ids, suffix_ids):
        """Return the fused word embedding; its last dimension has size ``embedding_dim``.

        Args:
            prefix_ids, spm_ids, suffix_ids: integer id tensors whose last axis is the
                sequence axis, padded with ``pad_id``.
        """
        # Branch order (spm, prefix, suffix) is the same for the computation and for
        # the concatenation fed to ``integration_net``.
        spm_word_emb = self._encode_spm(spm_ids)
        prefix_word_emb = self._encode_affix(
            prefix_ids, self.prefix_embedding, self.prefix_encoder, self.prefix_pooling
        )
        suffix_word_emb = self._encode_affix(
            suffix_ids, self.suffix_embedding, self.suffix_encoder, self.suffix_pooling
        )

        word_emb = self.integration_net(
            pt.cat([spm_word_emb, prefix_word_emb, suffix_word_emb], dim=-1)
        )

        return word_emb

In [221]:
class EnhancedWordDecoder(nn.Module):
    """Decode a word embedding back into prefix / SentencePiece / suffix id sequences.

    Each of the three components has its own projection of the word embedding, its own
    token embedding, a causal decoder stack and a classifier. The projected word
    embedding is prepended to the token embeddings as an extra leading position, so
    every token position can attend to it.

    The prefix and suffix stacks are ``nn.ModuleList``s of pre-norm residual blocks
    (``RoPEMHSA`` + ``FFN``); the SentencePiece stack is an ``Encoder`` from ``common``
    that is called as ``self.spm_decoder(x, mask)``.

    Args:
        vocab_size: indexable holding three vocabulary sizes, in the order
            prefix (index 0), SentencePiece (index 1), suffix (index 2).
        pad_id, bos_id, eos_id: special ids shared by all three id spaces.
        embedding_dim: size of the incoming word embedding.
        spm_dim, suffix_dim, prefix_dim: model widths of the three components.
        num_layers: number of blocks in each stack.
        num_heads: attention heads per block.
        ffn_dim: hidden width handed to ``FFN``.
        dropout: dropout probability handed to the blocks.
        rope_cache: forwarded to ``RoPEMHSA`` (presumably the RoPE cache length —
            confirm in ``common``).
    """

    def __init__(self, vocab_size, pad_id, bos_id, eos_id, embedding_dim=256,
                 spm_dim=512, suffix_dim=256, prefix_dim=256, num_layers=6,
                 num_heads=8, ffn_dim=1024, dropout=0.2, rope_cache=128):
        super().__init__()

        self.vocab_size = vocab_size
        self.pad_id = pad_id
        self.bos_id = bos_id
        self.eos_id = eos_id

        def make_blocks(dim):
            # One pre-norm residual block per layer: attention then feed-forward.
            return nn.ModuleList([
                nn.ModuleDict({
                    'attn': RoPEMHSA(dim, dim // num_heads, dim // num_heads,
                                    num_heads, dropout, rope_cache),
                    'ffn': FFN(dim, ffn_dim, dropout),
                    'norm1': nn.RMSNorm(dim),
                    'norm2': nn.RMSNorm(dim)
                })
                for _ in range(num_layers)
            ])

        # --- Prefix component -----------------------------------------------------
        self.prefix_proj = nn.Linear(embedding_dim, prefix_dim)
        self.prefix_embedding = nn.Embedding(vocab_size[0], prefix_dim, padding_idx=pad_id)
        nn.init.normal_(self.prefix_embedding.weight, mean=0.0, std=0.02)

        self.prefix_decoder_layers = make_blocks(prefix_dim)

        self.prefix_norm = nn.RMSNorm(prefix_dim)
        self.prefix_classifier = nn.Linear(prefix_dim, vocab_size[0])

        # --- SentencePiece component ----------------------------------------------
        self.spm_proj = nn.Linear(embedding_dim, spm_dim)
        self.spm_embedding = nn.Embedding(vocab_size[1], spm_dim, padding_idx=pad_id)
        nn.init.normal_(self.spm_embedding.weight, mean=0.0, std=0.02)

        # This component uses the shared `Encoder` wrapper instead of a hand-built
        # ModuleList; there is therefore no `spm_decoder_layers` attribute.
        self.spm_decoder = Encoder(num_layers,
            lambda: RoPEMHSA(spm_dim, spm_dim // num_heads, spm_dim // num_heads, num_heads, dropout, rope_cache),
            lambda: FFN(spm_dim, ffn_dim, dropout),
            lambda: nn.RMSNorm(spm_dim)
        )

        self.spm_norm = nn.RMSNorm(spm_dim)
        self.spm_classifier = nn.Linear(spm_dim, vocab_size[1])

        # --- Suffix component -----------------------------------------------------
        self.suffix_proj = nn.Linear(embedding_dim, suffix_dim)
        self.suffix_embedding = nn.Embedding(vocab_size[2], suffix_dim, padding_idx=pad_id)
        nn.init.normal_(self.suffix_embedding.weight, mean=0.0, std=0.02)

        self.suffix_decoder_layers = make_blocks(suffix_dim)

        self.suffix_norm = nn.RMSNorm(suffix_dim)
        self.suffix_classifier = nn.Linear(suffix_dim, vocab_size[2])

        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _causal_mask(input_emb):
        """Boolean lower-triangular mask of shape ``(*batch, L, L)`` for ``input_emb``."""
        batch_shape = input_emb.shape[:-2]
        L = input_emb.shape[-2]
        return pt.tril(pt.ones((*batch_shape, L, L), device=input_emb.device, dtype=pt.bool))

    def _decode_sequence(self, input_emb, decoder_layers, final_norm, classifier):
        """Run a ModuleList stack of pre-norm blocks causally and return logits."""
        mask = self._causal_mask(input_emb)

        x = input_emb
        for layer in decoder_layers:
            x = x + layer['attn'](layer['norm1'](x), mask)
            x = x + layer['ffn'](layer['norm2'](x))

        return classifier(final_norm(x))

    def _decode_sequence_exp(self, input_emb, decoder, final_norm, classifier):
        """Same as ``_decode_sequence`` but for a stack callable as ``decoder(x, mask)``."""
        mask = self._causal_mask(input_emb)

        return classifier(final_norm(decoder(input_emb, mask)))

    def _component_heads(self):
        """Return ``(proj, embedding, logits_fn)`` for prefix, spm and suffix, in that order.

        ``logits_fn`` maps decoder input embeddings to per-position logits. Keeping
        the wiring in one place guarantees that training, greedy decoding and beam
        search all run the same stack for each component.
        """
        return (
            (self.prefix_proj, self.prefix_embedding,
             lambda x: self._decode_sequence(
                 x, self.prefix_decoder_layers, self.prefix_norm, self.prefix_classifier)),
            (self.spm_proj, self.spm_embedding,
             lambda x: self._decode_sequence_exp(
                 x, self.spm_decoder, self.spm_norm, self.spm_classifier)),
            (self.suffix_proj, self.suffix_embedding,
             lambda x: self._decode_sequence(
                 x, self.suffix_decoder_layers, self.suffix_norm, self.suffix_classifier)),
        )

    @staticmethod
    def _prepare_generation(embedding, max_len):
        """Normalise generation inputs: per-component length limits and a (1, 1, E) embedding."""
        if isinstance(max_len, int):
            max_len = (max_len, max_len, max_len)
        if embedding.dim() == 1:
            embedding = embedding.unsqueeze(0)
        return embedding.unsqueeze(1), max_len

    def forward(self, prefix_ids, spm_ids, suffix_ids, embedding):
        """Teacher-forced decoding.

        Args:
            prefix_ids, spm_ids, suffix_ids: ``(batch, T)`` input ids (the training
                loop passes the targets without their last token).
            embedding: ``(batch, embedding_dim)`` word embeddings.

        Returns:
            ``(prefix_logits, spm_logits, suffix_logits)``. The logit at the prepended
            embedding position is dropped, so each tensor has ``T`` positions.
        """
        embedding_proj = embedding.unsqueeze(1)

        outputs = []
        for ids, (proj, embed, logits_fn) in zip((prefix_ids, spm_ids, suffix_ids), self._component_heads()):
            decoder_input = pt.cat([proj(embedding_proj), embed(ids)], dim=1)
            outputs.append(logits_fn(decoder_input)[:, 1:, :])

        return tuple(outputs)

    def _greedy_decode(self, emb_proj, embedding_fn, logits_fn, max_seq_len):
        """Generate ids by repeatedly appending the arg-max token, starting from BOS.

        Stops after EOS is produced or after ``max_seq_len`` tokens were appended.
        """
        ids = [self.bos_id]
        for _ in range(max_seq_len):
            curr_seq = pt.tensor([ids], device=emb_proj.device)
            curr_input = pt.cat([emb_proj, embedding_fn(curr_seq)], dim=1)

            next_token = logits_fn(curr_input)[0, -1].argmax(-1).item()
            ids.append(next_token)

            if next_token == self.eos_id:
                break
        return ids

    @pt.inference_mode()
    def inference(self, embedding, max_len=64):
        """Greedy decoding of a single word embedding.

        Puts the module in eval mode. ``max_len`` is either one int applied to all
        three components or a ``(prefix, spm, suffix)`` triple.

        Returns:
            ``(prefix_ids, spm_ids, suffix_ids)`` as Python lists starting with BOS.
        """
        self.eval()
        embedding_proj, max_len = self._prepare_generation(embedding, max_len)

        return tuple(
            self._greedy_decode(proj(embedding_proj), embed, logits_fn, limit)
            for limit, (proj, embed, logits_fn) in zip(max_len, self._component_heads())
        )

    def _beam_search_component(self, emb_proj, embedding_fn, logits_fn, max_seq_len,
                               beam_size, length_penalty):
        """Beam search for one component; returns the best id list (starting with BOS).

        NOTE: a hypothesis that ends with EOS has its score divided by
        ``(length - 1) ** length_penalty`` at the moment it finishes, while open
        hypotheses keep their raw summed log-probability; both are ranked together.
        """
        beams = [([self.bos_id], 0.0)]

        for _ in range(max_seq_len):
            candidates = []

            for seq, score in beams:
                if seq[-1] == self.eos_id:
                    # Finished hypotheses are carried over unchanged.
                    candidates.append((seq, score))
                    continue

                curr_input_ids = pt.tensor([seq], device=emb_proj.device)
                curr_input = pt.cat([emb_proj, embedding_fn(curr_input_ids)], dim=1)

                log_probs = F.log_softmax(logits_fn(curr_input)[0, -1], dim=-1)
                # `topk` raises if asked for more entries than the vocabulary holds.
                k = min(beam_size, log_probs.size(-1))
                topk_log_probs, topk_indices = log_probs.topk(k)

                for token_score, token_id in zip(topk_log_probs.tolist(), topk_indices.tolist()):
                    new_seq = seq + [token_id]
                    new_score = score + token_score

                    if token_id == self.eos_id:
                        new_score = new_score / ((len(new_seq) - 1) ** length_penalty)

                    candidates.append((new_seq, new_score))

            candidates.sort(key=lambda cand: cand[1], reverse=True)
            beams = candidates[:beam_size]

            if all(seq[-1] == self.eos_id for seq, _ in beams):
                break

        return beams[0][0]

    @pt.inference_mode()
    def beam_search(self, embedding, beam_size=5, max_len=64, length_penalty=1.0):
        """Beam-search decoding of a single word embedding.

        Puts the module in eval mode. Every component is decoded through the same
        stack that ``forward`` and ``inference`` use; previously the SentencePiece
        branch read the removed ``spm_decoder_layers`` attribute and raised
        ``AttributeError``.

        Args:
            embedding: ``(embedding_dim,)`` or ``(1, embedding_dim)`` word embedding.
            beam_size: number of hypotheses kept per step.
            max_len: one int for all components or a ``(prefix, spm, suffix)`` triple.
            length_penalty: exponent of the length normalisation applied on EOS.

        Returns:
            ``(prefix_ids, spm_ids, suffix_ids)`` as Python lists starting with BOS.
        """
        self.eval()
        embedding_proj, max_len = self._prepare_generation(embedding, max_len)

        return tuple(
            self._beam_search_component(
                proj(embedding_proj), embed, logits_fn, limit, beam_size, length_penalty
            )
            for limit, (proj, embed, logits_fn) in zip(max_len, self._component_heads())
        )


In [17]:
class ContrastiveLoss(nn.Module):
    """Temperature-scaled cosine-similarity contrastive loss over a batch of embeddings.

    Similarities of an embedding with itself are always excluded.

    * With ``labels``: for every anchor that has at least one other sample with the
      same label, the loss is ``-log(sum_pos exp(s) / sum_others exp(s))``; the result
      is the mean over those anchors. Anchors without a positive are skipped, and if
      no anchor has one the loss is zero.
    * Without ``labels``: the negative mean of the row-wise log-softmax over all
      off-diagonal similarities.

    Args:
        temperature: divisor applied to the cosine similarities.
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, embeddings, labels=None):
        """Return the scalar loss for ``embeddings`` of shape ``(batch, dim)``.

        Args:
            embeddings: 2-D tensor, one row per sample.
            labels: optional tensor with one label per sample; equal labels mark
                positives.
        """
        batch_size = embeddings.shape[0]
        if batch_size < 2:
            # No pair to contrast; return a zero that is still attached to the graph.
            return embeddings.sum() * 0.0

        embeddings_normalized = F.normalize(embeddings, p=2, dim=1)
        similarity_matrix = pt.matmul(
            embeddings_normalized, embeddings_normalized.transpose(0, 1)
        ) / self.temperature

        self_mask = pt.eye(batch_size, dtype=pt.bool, device=embeddings.device)
        similarity_matrix = similarity_matrix.masked_fill(self_mask, -float('inf'))

        if labels is None:
            # The diagonal is -inf after log-softmax and must be left out of the
            # mean, otherwise the loss is always infinite.
            log_probs = F.log_softmax(similarity_matrix, dim=1)
            return -log_probs[~self_mask].mean()

        labels = labels.to(embeddings.device).contiguous().view(-1, 1)
        positive_mask = pt.eq(labels, labels.transpose(0, 1)) & ~self_mask
        has_positive = positive_mask.any(dim=1)
        if not has_positive.any():
            return embeddings.sum() * 0.0

        # Work in log space. Non-positives are masked with -inf so that they add
        # exp(-inf) = 0 to the numerator; multiplying the similarities by a 0/1 mask
        # would make each of them add exp(0) = 1 instead.
        anchor_sims = similarity_matrix[has_positive]
        log_denominator = pt.logsumexp(anchor_sims, dim=1)
        log_numerator = pt.logsumexp(
            anchor_sims.masked_fill(~positive_mask[has_positive], -float('inf')), dim=1
        )

        return (log_denominator - log_numerator).mean()

class EnhancedAutoencoderLoss(nn.Module):
    """Weighted sum of a token-level reconstruction loss and a contrastive loss.

    The reconstruction term is the label-smoothed cross-entropy summed over the
    prefix, spm and suffix targets and divided by the number of non-padding target
    tokens. The contrastive term is evaluated only when embeddings are supplied.

    Args:
        pad_id: target id that is ignored by the reconstruction loss.
        lambda_rec: weight of the reconstruction term.
        lambda_contrastive: weight of the contrastive term.
    """

    def __init__(self, pad_id, lambda_rec=1.0, lambda_contrastive=0.2):
        super().__init__()
        self.pad_id = pad_id
        self.lambda_rec = lambda_rec
        self.lambda_contrastive = lambda_contrastive

        # Summed (not averaged) so the normalisation by token count happens in forward.
        self.rec_criterion = nn.CrossEntropyLoss(
            reduction='sum',
            label_smoothing=0.1,
            ignore_index=pad_id
        )

        self.contrastive_criterion = ContrastiveLoss(temperature=0.1)

    def forward(self, prefix_logits, spm_logits, suffix_logits,
                prefix_ids, spm_ids, suffix_ids, embeddings=None, word_labels=None):
        """Return ``(total_loss, stats)``.

        ``stats`` is a dict of Python floats with the keys ``'rec_loss'``,
        ``'contrastive_loss'`` and ``'total_loss'``.
        """
        components = (
            (prefix_logits, prefix_ids),
            (spm_logits, spm_ids),
            (suffix_logits, suffix_ids),
        )

        # Summed cross-entropy per component over flattened (tokens, vocab) logits.
        summed = [
            self.rec_criterion(logits.reshape(-1, logits.shape[-1]), ids.reshape(-1))
            for logits, ids in components
        ]
        # Number of real (non-padding) targets per component.
        counts = [(ids != self.pad_id).sum() for _, ids in components]

        total_tokens = counts[0] + counts[1] + counts[2]
        rec_loss = (summed[0] + summed[1] + summed[2]) / total_tokens

        if embeddings is None:
            contrastive_loss = pt.tensor(0.0, device=prefix_logits.device)
            contrastive_value = 0.0
        else:
            contrastive_loss = self.contrastive_criterion(embeddings, word_labels)
            contrastive_value = contrastive_loss.item()

        total_loss = self.lambda_rec * rec_loss + self.lambda_contrastive * contrastive_loss

        stats = {
            'rec_loss': rec_loss.item(),
            'contrastive_loss': contrastive_value,
            'total_loss': total_loss.item()
        }
        return total_loss, stats

class CurriculumSampler:
    """Draw random batches from a pool of easy samples that widens with the epoch.

    Samples are ranked by a complexity score. At epoch ``e`` the pool consists of the
    easiest ``initial_difficulty + (1 - initial_difficulty) * e / epochs`` fraction of
    the dataset, never smaller than the requested batch size (as far as the dataset
    allows).

    Args:
        dataset: indexable whose items are ``(prefix, spm, suffix)`` id sequences.
        mopiece: tokenizer providing ``pad_id()``.
        metrics: optional per-sample complexity scores, indexable by dataset index.
            When omitted, the score is the number of non-padding ids in the sample,
            divided by the largest such count in the dataset.
        initial_difficulty: fraction of the dataset available at epoch 0.
        epochs: epoch count over which the pool grows to the full dataset.
    """

    def __init__(self, dataset, mopiece, metrics=None, initial_difficulty=0.2, epochs=50):
        self.dataset = dataset
        self.mopiece = mopiece
        self.epochs = epochs
        self.initial_difficulty = initial_difficulty
        self.current_epoch = 0

        if metrics is not None:
            self.word_complexities = metrics
        else:
            pad = mopiece.pad_id()
            token_counts = []
            for idx in range(len(dataset)):
                prefix_ids, spm_ids, suffix_ids = dataset[idx]
                token_counts.append(sum(
                    1
                    for sequence in (prefix_ids, spm_ids, suffix_ids)
                    for token in sequence
                    if token != pad
                ))

            largest = max(token_counts)
            self.word_complexities = [count / largest for count in token_counts]

        # Dataset indices ordered from the easiest to the hardest sample.
        self.indices = sorted(range(len(dataset)), key=lambda i: self.word_complexities[i])

    def update_epoch(self, epoch):
        """Record the current epoch; it controls how large the sampling pool is."""
        self.current_epoch = epoch

    def get_indices(self, batch_size):
        """Return up to ``batch_size`` distinct dataset indices drawn from the current pool."""
        progress = self.current_epoch / self.epochs
        difficulty = self.initial_difficulty + (1.0 - self.initial_difficulty) * progress

        dataset_size = len(self.dataset)
        pool_size = min(int(difficulty * dataset_size), dataset_size)
        pool_size = max(batch_size, pool_size)
        pool = self.indices[:pool_size]

        picks = pt.randperm(len(pool))[:batch_size].tolist()
        return [pool[pick] for pick in picks]

    def get_batch(self, batch_size):
        """Sample a batch and pad it.

        Returns:
            ``(prefix_batch, spm_batch, suffix_batch, indices)`` where the first three
            are batch-first tensors padded with the pad id and ``indices`` is a tensor
            of the sampled dataset indices.
        """
        indices = self.get_indices(batch_size)
        samples = [self.dataset[i] for i in indices]

        def padded(position):
            return pt.nn.utils.rnn.pad_sequence(
                [sample[position] for sample in samples],
                batch_first=True,
                padding_value=self.mopiece.pad_id()
            )

        prefix_batch = padded(0)
        spm_batch = padded(1)
        suffix_batch = padded(2)

        return prefix_batch, spm_batch, suffix_batch, pt.tensor(indices)

In [18]:
def train_improved_word_autoencoder(encoder, decoder, dataset, mopiece,
                              num_epochs=75, batch_size=64, lr=4e-4,
                              weight_decay=0.01, device='cuda',
                              schedule_teacher_forcing=True):
    """Train the encoder/decoder pair to reconstruct words from their embedding.

    Each batch is encoded, the embedding is passed through dropout (p=0.1),
    and the decoder is asked to predict the prefix / spm / suffix token
    sequences shifted by one position. The three cross-entropy losses
    (padding ignored, label smoothing 0.1) are summed. Optimisation uses
    AdamW with a per-batch OneCycleLR schedule and gradient-norm clipping
    at 1.0. Every 10th epoch and on the last epoch a few reconstructions are
    printed via ``test_word_generation``.

    NOTE(review): ``tf_ratio`` is computed, displayed and stored in the
    history, but it is never passed to the decoder -- the decoder always
    receives the ground-truth inputs, so training is effectively fully
    teacher-forced. Wiring the ratio in needs a decoder API not visible here.

    Args:
        encoder: Module called as ``encoder(prefix_ids, spm_ids, suffix_ids)``.
        decoder: Module called as ``decoder(prefix_in, spm_in, suffix_in, embeddings)``
            and returning three logit tensors.
        dataset: Dataset fed to the ``DataLoader``; batches are built by the
            notebook-level ``collate_fn`` (defined outside this function).
        mopiece: Tokenizer; only ``pad_id()`` is used here (it is also
            forwarded to ``test_word_generation``).
        num_epochs: Number of passes over the data.
        batch_size: Batch size; incomplete final batches are dropped.
        lr: Peak learning rate of the one-cycle schedule.
        weight_decay: AdamW weight decay.
        device: Device the batch tensors are moved to. The models are not
            moved by this function.
        schedule_teacher_forcing: When true, the recorded ``tf_ratio`` decays
            linearly from 1.0 to 0.5 over the run; otherwise it stays at 1.0.

    Returns:
        ``(history, encoder, decoder)`` where ``history`` has per-epoch lists
        under the keys ``'loss'`` and ``'tf_ratio'``.
    """
    criterion = nn.CrossEntropyLoss(
        ignore_index=mopiece.pad_id(),
        label_smoothing=0.1
    )

    dataloader = pt.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True,
        drop_last=True, collate_fn=collate_fn
    )

    optimizer = pt.optim.AdamW([
        {"params": encoder.parameters(), "lr": lr},
        {"params": decoder.parameters(), "lr": lr}
    ], weight_decay=weight_decay)

    lr_scheduler = pt.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr,
        steps_per_epoch=len(dataloader),
        epochs=num_epochs,
        pct_start=0.3,
        div_factor=25,
        final_div_factor=1000
    )

    history = {
        'loss': [],
        'tf_ratio': []
    }

    # Built once: the parameter set does not change between batches.
    clip_params = list(encoder.parameters()) + list(decoder.parameters())

    # Guard against num_epochs == 1, where (num_epochs - 1) would be zero.
    tf_denominator = max(1, num_epochs - 1)

    encoder.train()
    decoder.train()

    for epoch in (pbar := trange(num_epochs, desc='Обучение')):
        epoch_loss = 0.0

        if schedule_teacher_forcing:
            tf_ratio = max(0.5, 1.0 - 0.5 * (epoch / tf_denominator))
        else:
            tf_ratio = 1.0

        history['tf_ratio'].append(tf_ratio)

        for prefix_ids, spm_ids, suffix_ids in tqdm(dataloader, desc=f'Эпоха {epoch+1}', leave=False):
            prefix_ids = prefix_ids.to(device)
            spm_ids = spm_ids.to(device)
            suffix_ids = suffix_ids.to(device)

            optimizer.zero_grad()

            embeddings = encoder(prefix_ids, spm_ids, suffix_ids)

            # Extra regularisation on the bottleneck; uses the fully qualified
            # path so the cell does not rely on a separate `F` alias import.
            embeddings = pt.nn.functional.dropout(embeddings, p=0.1, training=True)

            # Decoder inputs drop the last token; targets drop the first one.
            prefix_logits, spm_logits, suffix_logits = decoder(
                prefix_ids[..., :-1],
                spm_ids[..., :-1],
                suffix_ids[..., :-1],
                embeddings
            )

            prefix_loss = criterion(
                prefix_logits.reshape(-1, prefix_logits.shape[-1]),
                prefix_ids[..., 1:].reshape(-1)
            )

            spm_loss = criterion(
                spm_logits.reshape(-1, spm_logits.shape[-1]),
                spm_ids[..., 1:].reshape(-1)
            )

            suffix_loss = criterion(
                suffix_logits.reshape(-1, suffix_logits.shape[-1]),
                suffix_ids[..., 1:].reshape(-1)
            )

            loss = prefix_loss + spm_loss + suffix_loss

            loss.backward()

            pt.nn.utils.clip_grad_norm_(clip_params, max_norm=1.0)

            optimizer.step()

            # OneCycleLR is configured in steps, so it advances once per batch.
            lr_scheduler.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        history['loss'].append(avg_loss)

        pbar.set_postfix_str(
            f'Loss: {avg_loss:.4f}, TF: {tf_ratio:.2f}, LR: {optimizer.param_groups[0]["lr"]:.6f}'
        )

        if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
            test_word_generation(encoder, decoder, dataset, mopiece, device, beam_search=True)

    return history, encoder, decoder

def test_word_generation(encoder, decoder, dataset, mopiece, device, num_samples=3, beam_search=True):
    encoder.eval()
    decoder.eval()

    indices = pt.randperm(len(dataset))[:num_samples].tolist()

    print("\n--- Тестирование ---")
    correct_count = 0

    for idx in indices:
        prefix_ids, spm_ids, suffix_ids = dataset[idx]

        prefix_ids = prefix_ids.to(device)
        spm_ids = spm_ids.to(device)
        suffix_ids = suffix_ids.to(device)

        original_word = mopiece.decode_word(
            prefix_ids.tolist()[1:-1],
            spm_ids.tolist()[1:-1],
            suffix_ids.tolist()[1:-1]
        )

        with pt.no_grad():
            embedding = encoder(
                prefix_ids.unsqueeze(0),
                spm_ids.unsqueeze(0),
                suffix_ids.unsqueeze(0)
            ).squeeze(0)

            if beam_search:
                gen_prefix_ids, gen_spm_ids, gen_suffix_ids = decoder.beam_search(
                    embedding, beam_size=5
                )
            else:
                gen_prefix_ids, gen_spm_ids, gen_suffix_ids = decoder.inference(
                    embedding
                )

            generated_word = mopiece.decode_word(
                [t for t in gen_prefix_ids[1:] if t != decoder.eos_id and t != decoder.pad_id],
                [t for t in gen_spm_ids[1:] if t != decoder.eos_id and t != decoder.pad_id],
                [t for t in gen_suffix_ids[1:] if t != decoder.eos_id and t != decoder.pad_id]
            )

            is_correct = generated_word == original_word
            if is_correct:
                correct_count += 1

            print(f"Оригинал: '{original_word}'")
            print(f"Сгенерировано: '{generated_word}' {'✓' if is_correct else '✗'}")
            print("-" * 40)

    encoder.train()
    decoder.train()


In [19]:
# Build the encoder and move it to the selected device.
# The first two positional arguments are the tokenizer's vocabulary size and padding id.
encoder = EnhancedWordEncoder(
    mopiece.vocab_size(),
    mopiece.pad_id(),
    embedding_dim=256,
    spm_dim=512,
    suffix_dim=256,
    prefix_dim=256,
    spm_layers=8,
    spm_heads=16,
    ffn_hidden_dim=1024,
    dropout=0.1
).to(device)

# Build the decoder with the same embedding / spm / suffix / prefix widths as the encoder.
# It additionally receives the tokenizer's BOS and EOS ids.
decoder = EnhancedWordDecoder(
    mopiece.vocab_size(),
    mopiece.pad_id(),
    mopiece.bos_id(),
    mopiece.eos_id(),
    embedding_dim=256,
    spm_dim=512,
    suffix_dim=256,
    prefix_dim=256,
    num_layers=8,
    num_heads=8,
    ffn_dim=1024,
    dropout=0.1
).to(device)

# Run the full training loop (100 epochs, batch size 32, peak LR 3e-4).
# NOTE(review): `dset` is not defined in this cell; it must already exist in the
# kernel from an earlier cell -- confirm the notebook survives Restart & Run All.
history, trained_encoder, trained_decoder = train_improved_word_autoencoder(
    encoder,
    decoder,
    dset,
    mopiece,
    num_epochs=100,
    batch_size=32,
    lr=3e-4,
    weight_decay=0.01,
    device=device,
    schedule_teacher_forcing=True
)

# Persist both state dicts together with the per-epoch training history.
# The path is relative to the current working directory.
pt.save({
    'encoder_state_dict': trained_encoder.state_dict(),
    'decoder_state_dict': trained_decoder.state_dict(),
    'history': history
}, 'improved_word_autoencoder.pt')


Обучение:   0%|          | 0/100 [00:00<?, ?it/s]
Эпоха 1:   0%|          | 0/405 [00:00<?, ?it/s][A
Эпоха 1:   0%|          | 1/405 [00:04<27:03,  4.02s/it][A
Эпоха 1:   0%|          | 2/405 [00:04<13:12,  1.97s/it][A
Эпоха 1:   1%|          | 3/405 [00:05<09:42,  1.45s/it][A
Эпоха 1:   1%|          | 4/405 [00:06<08:35,  1.29s/it][A
Эпоха 1:   1%|          | 5/405 [00:06<06:14,  1.07it/s][A
Эпоха 1:   1%|▏         | 6/405 [00:07<05:22,  1.24it/s][A
Эпоха 1:   2%|▏         | 7/405 [00:07<04:12,  1.58it/s][A
Эпоха 1:   2%|▏         | 8/405 [00:07<03:29,  1.90it/s][A
Эпоха 1:   2%|▏         | 9/405 [00:08<03:00,  2.20it/s][A
Эпоха 1:   2%|▏         | 10/405 [00:08<02:39,  2.48it/s][A
Эпоха 1:   3%|▎         | 11/405 [00:08<02:25,  2.70it/s][A
Эпоха 1:   3%|▎         | 12/405 [00:09<02:17,  2.86it/s][A
Эпоха 1:   3%|▎         | 13/405 [00:09<02:11,  2.99it/s][A
Эпоха 1:   3%|▎         | 14/405 [00:09<02:06,  3.10it/s][A
Эпоха 1:   4%|▎         | 15/405 [00:10<03:51,  1.68i


--- Тестирование генерации слов ---
Оригинал: 'fillets'
Сгенерировано: 'fillets' ✓
----------------------------------------


Обучение:  10%|█         | 10/100 [20:58<3:09:40, 126.44s/it, Loss: 4.0232, TF: 0.95, LR: 0.000084]

Оригинал: 'nahum'
Сгенерировано: 'nahum' ✓
----------------------------------------
Оригинал: 'revive'
Сгенерировано: 'revive' ✓
----------------------------------------
Точность: 100.0% (3/3)
--- Конец тестирования ---




Эпоха 11:   0%|          | 0/405 [00:00<?, ?it/s][A
Эпоха 11:   0%|          | 1/405 [00:00<02:17,  2.93it/s][A
Эпоха 11:   0%|          | 2/405 [00:00<02:08,  3.14it/s][A
Эпоха 11:   1%|          | 3/405 [00:00<02:04,  3.24it/s][A
Эпоха 11:   1%|          | 4/405 [00:01<02:01,  3.31it/s][A
Эпоха 11:   1%|          | 5/405 [00:01<02:01,  3.29it/s][A
Эпоха 11:   1%|▏         | 6/405 [00:01<02:02,  3.25it/s][A
Эпоха 11:   2%|▏         | 7/405 [00:02<02:01,  3.27it/s][A
Эпоха 11:   2%|▏         | 8/405 [00:02<02:01,  3.27it/s][A
Эпоха 11:   2%|▏         | 9/405 [00:02<01:58,  3.35it/s][A
Эпоха 11:   2%|▏         | 10/405 [00:03<01:56,  3.38it/s][A
Эпоха 11:   3%|▎         | 11/405 [00:03<01:56,  3.38it/s][A
Эпоха 11:   3%|▎         | 12/405 [00:03<01:55,  3.39it/s][A
Эпоха 11:   3%|▎         | 13/405 [00:03<01:55,  3.39it/s][A
Эпоха 11:   3%|▎         | 14/405 [00:04<01:58,  3.31it/s][A
Эпоха 11:   4%|▎         | 15/405 [00:04<01:55,  3.38it/s][A
Эпоха 11:   4%|▍         


--- Тестирование генерации слов ---
Оригинал: 'eleasah'
Сгенерировано: 'elezia' ✗
----------------------------------------
Оригинал: 'prisons'
Сгенерировано: 'prickson' ✗
----------------------------------------


Обучение:  20%|██        | 20/100 [41:13<2:42:29, 121.87s/it, Loss: 3.3682, TF: 0.90, LR: 0.000228]

Оригинал: 'niger'
Сгенерировано: 'nigifer' ✗
----------------------------------------
Точность: 0.0% (0/3)
--- Конец тестирования ---




Эпоха 21:   0%|          | 0/405 [00:00<?, ?it/s][A
Эпоха 21:   0%|          | 1/405 [00:00<02:17,  2.94it/s][A
Эпоха 21:   0%|          | 2/405 [00:00<02:09,  3.11it/s][A
Эпоха 21:   1%|          | 3/405 [00:00<02:05,  3.20it/s][A
Эпоха 21:   1%|          | 4/405 [00:01<02:04,  3.21it/s][A
Эпоха 21:   1%|          | 5/405 [00:01<02:04,  3.20it/s][A
Эпоха 21:   1%|▏         | 6/405 [00:01<02:03,  3.24it/s][A
Эпоха 21:   2%|▏         | 7/405 [00:02<02:01,  3.27it/s][A
Эпоха 21:   2%|▏         | 8/405 [00:02<01:58,  3.35it/s][A
Эпоха 21:   2%|▏         | 9/405 [00:02<01:57,  3.38it/s][A
Эпоха 21:   2%|▏         | 10/405 [00:03<01:56,  3.39it/s][A
Эпоха 21:   3%|▎         | 11/405 [00:03<01:53,  3.46it/s][A
Эпоха 21:   3%|▎         | 12/405 [00:03<01:53,  3.47it/s][A
Эпоха 21:   3%|▎         | 13/405 [00:03<01:51,  3.51it/s][A
Эпоха 21:   3%|▎         | 14/405 [00:04<01:50,  3.55it/s][A
Эпоха 21:   4%|▎         | 15/405 [00:04<01:49,  3.55it/s][A
Эпоха 21:   4%|▍         


--- Тестирование генерации слов ---
Оригинал: 'amminadab'
Сгенерировано: 'amminadab' ✓
----------------------------------------
Оригинал: 'profession'
Сгенерировано: 'profsion' ✗
----------------------------------------


Обучение:  30%|███       | 30/100 [1:01:24<2:21:09, 121.00s/it, Loss: 3.2717, TF: 0.85, LR: 0.000300]

Оригинал: 'fried'
Сгенерировано: 'fri' ✗
----------------------------------------
Точность: 33.3% (1/3)
--- Конец тестирования ---




Эпоха 31:   0%|          | 0/405 [00:00<?, ?it/s][A
Эпоха 31:   0%|          | 1/405 [00:00<02:11,  3.08it/s][A
Эпоха 31:   0%|          | 2/405 [00:00<02:05,  3.20it/s][A
Эпоха 31:   1%|          | 3/405 [00:00<02:02,  3.29it/s][A
Эпоха 31:   1%|          | 4/405 [00:01<02:00,  3.34it/s][A
Эпоха 31:   1%|          | 5/405 [00:01<02:00,  3.32it/s][A
Эпоха 31:   1%|▏         | 6/405 [00:01<01:53,  3.51it/s][A
Эпоха 31:   2%|▏         | 7/405 [00:02<01:57,  3.39it/s][A
Эпоха 31:   2%|▏         | 8/405 [00:02<01:58,  3.36it/s][A
Эпоха 31:   2%|▏         | 9/405 [00:02<01:58,  3.34it/s][A
Эпоха 31:   2%|▏         | 10/405 [00:02<01:57,  3.36it/s][A
Эпоха 31:   3%|▎         | 11/405 [00:03<01:57,  3.35it/s][A
Эпоха 31:   3%|▎         | 12/405 [00:03<02:00,  3.27it/s][A
Эпоха 31:   3%|▎         | 13/405 [00:03<01:56,  3.35it/s][A
Эпоха 31:   3%|▎         | 14/405 [00:04<01:58,  3.30it/s][A
Эпоха 31:   4%|▎         | 15/405 [00:04<01:58,  3.30it/s][A
Эпоха 31:   4%|▍         


--- Тестирование генерации слов ---
Оригинал: 'lepers'
Сгенерировано: 'evaners' ✗
----------------------------------------
Оригинал: 'eleventh'
Сгенерировано: 'eleven' ✗
----------------------------------------


Обучение:  40%|████      | 40/100 [1:22:31<2:06:31, 126.52s/it, Loss: 3.1887, TF: 0.80, LR: 0.000285]

Оригинал: 'shimrith'
Сгенерировано: 'shimri' ✗
----------------------------------------
Точность: 0.0% (0/3)
--- Конец тестирования ---




Эпоха 41:   0%|          | 0/405 [00:00<?, ?it/s][A
Эпоха 41:   0%|          | 1/405 [00:00<02:09,  3.13it/s][A
Эпоха 41:   0%|          | 2/405 [00:00<02:08,  3.12it/s][A
Эпоха 41:   1%|          | 3/405 [00:00<02:08,  3.13it/s][A
Эпоха 41:   1%|          | 4/405 [00:01<02:08,  3.12it/s][A
Эпоха 41:   1%|          | 5/405 [00:01<02:04,  3.21it/s][A
Эпоха 41:   1%|▏         | 6/405 [00:01<02:04,  3.22it/s][A
Эпоха 41:   2%|▏         | 7/405 [00:02<02:03,  3.21it/s][A
Эпоха 41:   2%|▏         | 8/405 [00:02<02:02,  3.24it/s][A
Эпоха 41:   2%|▏         | 9/405 [00:02<02:01,  3.27it/s][A
Эпоха 41:   2%|▏         | 10/405 [00:03<02:00,  3.27it/s][A
Эпоха 41:   3%|▎         | 11/405 [00:03<01:57,  3.34it/s][A
Эпоха 41:   3%|▎         | 12/405 [00:03<01:57,  3.34it/s][A
Эпоха 41:   3%|▎         | 13/405 [00:03<01:57,  3.34it/s][A
Эпоха 41:   3%|▎         | 14/405 [00:04<01:55,  3.38it/s][A
Эпоха 41:   4%|▎         | 15/405 [00:04<01:55,  3.36it/s][A
Эпоха 41:   4%|▍         


--- Тестирование генерации слов ---
Оригинал: 'highly'
Сгенерировано: 'highly' ✓
----------------------------------------


Обучение:  50%|█████     | 50/100 [1:42:44<1:41:57, 122.35s/it, Loss: 3.1428, TF: 0.75, LR: 0.000244]

Оригинал: 'filledst'
Сгенерировано: 'fillst' ✗
----------------------------------------
Оригинал: 'winter'
Сгенерировано: 'winter' ✓
----------------------------------------
Точность: 66.7% (2/3)
--- Конец тестирования ---




Эпоха 51:   0%|          | 0/405 [00:00<?, ?it/s][A
Эпоха 51:   0%|          | 1/405 [00:00<02:10,  3.10it/s][A
Эпоха 51:   0%|          | 2/405 [00:00<02:10,  3.09it/s][A
Эпоха 51:   1%|          | 3/405 [00:00<02:04,  3.24it/s][A
Эпоха 51:   1%|          | 4/405 [00:01<02:05,  3.20it/s][A
Эпоха 51:   1%|          | 5/405 [00:01<02:06,  3.17it/s][A
Эпоха 51:   1%|▏         | 6/405 [00:01<02:04,  3.20it/s][A
Эпоха 51:   2%|▏         | 7/405 [00:02<02:02,  3.24it/s][A
Эпоха 51:   2%|▏         | 8/405 [00:02<02:01,  3.26it/s][A
Эпоха 51:   2%|▏         | 9/405 [00:02<02:03,  3.22it/s][A
Эпоха 51:   2%|▏         | 10/405 [00:03<02:01,  3.26it/s][A
Эпоха 51:   3%|▎         | 11/405 [00:03<02:00,  3.28it/s][A
Эпоха 51:   3%|▎         | 12/405 [00:03<02:02,  3.22it/s][A
Эпоха 51:   3%|▎         | 13/405 [00:04<02:04,  3.14it/s][A
Эпоха 51:   3%|▎         | 14/405 [00:04<02:05,  3.12it/s][A
Эпоха 51:   4%|▎         | 15/405 [00:04<02:06,  3.07it/s][A
Эпоха 51:   4%|▍         


--- Тестирование генерации слов ---
Оригинал: 'pourtrayed'
Сгенерировано: 'pourted' ✗
----------------------------------------
Оригинал: 'entangled'
Сгенерировано: 'tang' ✗
----------------------------------------


Обучение:  60%|██████    | 60/100 [2:03:46<1:23:42, 125.57s/it, Loss: 3.1158, TF: 0.70, LR: 0.000183]

Оригинал: 'criest'
Сгенерировано: 'criest' ✓
----------------------------------------
Точность: 33.3% (1/3)
--- Конец тестирования ---




Эпоха 61:   0%|          | 0/405 [00:00<?, ?it/s][A
Эпоха 61:   0%|          | 1/405 [00:00<02:22,  2.83it/s][A
Эпоха 61:   0%|          | 2/405 [00:00<02:09,  3.10it/s][A
Эпоха 61:   1%|          | 3/405 [00:00<02:06,  3.17it/s][A
Эпоха 61:   1%|          | 4/405 [00:01<02:05,  3.20it/s][A
Эпоха 61:   1%|          | 5/405 [00:01<02:07,  3.14it/s][A
Эпоха 61:   1%|▏         | 6/405 [00:01<02:10,  3.06it/s][A
Эпоха 61:   2%|▏         | 7/405 [00:02<02:11,  3.02it/s][A
Эпоха 61:   2%|▏         | 8/405 [00:02<02:08,  3.08it/s][A
Эпоха 61:   2%|▏         | 9/405 [00:02<02:05,  3.16it/s][A
Эпоха 61:   2%|▏         | 10/405 [00:03<02:04,  3.16it/s][A
Эпоха 61:   3%|▎         | 11/405 [00:03<02:05,  3.13it/s][A
Эпоха 61:   3%|▎         | 12/405 [00:03<02:04,  3.15it/s][A
Эпоха 61:   3%|▎         | 13/405 [00:04<02:03,  3.17it/s][A
Эпоха 61:   3%|▎         | 14/405 [00:04<02:05,  3.11it/s][A
Эпоха 61:   4%|▎         | 15/405 [00:04<02:04,  3.13it/s][A
Эпоха 61:   4%|▍         


--- Тестирование генерации слов ---
Оригинал: 'shouting'
Сгенерировано: 'shouting' ✓
----------------------------------------


Обучение:  70%|███████   | 70/100 [2:25:02<1:03:02, 126.09s/it, Loss: 3.0928, TF: 0.65, LR: 0.000117]

Оригинал: 'riphath'
Сгенерировано: 'riphath' ✓
----------------------------------------
Оригинал: 'bray'
Сгенерировано: 'braed' ✗
----------------------------------------
Точность: 66.7% (2/3)
--- Конец тестирования ---




Эпоха 71:   0%|          | 0/405 [00:00<?, ?it/s][A
Эпоха 71:   0%|          | 1/405 [00:00<02:07,  3.18it/s][A
Эпоха 71:   0%|          | 2/405 [00:00<02:01,  3.30it/s][A
Эпоха 71:   1%|          | 3/405 [00:00<02:00,  3.33it/s][A
Эпоха 71:   1%|          | 4/405 [00:01<02:00,  3.33it/s][A
Эпоха 71:   1%|          | 5/405 [00:01<02:01,  3.28it/s][A
Эпоха 71:   1%|▏         | 6/405 [00:01<02:01,  3.29it/s][A
Эпоха 71:   2%|▏         | 7/405 [00:02<02:05,  3.16it/s][A
Эпоха 71:   2%|▏         | 8/405 [00:02<02:05,  3.17it/s][A
Эпоха 71:   2%|▏         | 9/405 [00:02<02:03,  3.19it/s][A
Эпоха 71:   2%|▏         | 10/405 [00:03<02:04,  3.18it/s][A
Эпоха 71:   3%|▎         | 11/405 [00:03<02:01,  3.24it/s][A
Эпоха 71:   3%|▎         | 12/405 [00:03<01:59,  3.29it/s][A
Эпоха 71:   3%|▎         | 13/405 [00:04<02:00,  3.24it/s][A
Эпоха 71:   3%|▎         | 14/405 [00:04<02:04,  3.13it/s][A
Эпоха 71:   4%|▎         | 15/405 [00:04<02:06,  3.08it/s][A
Эпоха 71:   4%|▍         


--- Тестирование генерации слов ---
Оригинал: 'benefits'
Сгенерировано: 'nefs' ✗
----------------------------------------


Обучение:  80%|████████  | 80/100 [2:46:16<42:36, 127.82s/it, Loss: 3.0830, TF: 0.60, LR: 0.000056]

Оригинал: 'famines'
Сгенерировано: 'famin' ✗
----------------------------------------
Оригинал: 'shisha'
Сгенерировано: 'shisha' ✓
----------------------------------------
Точность: 33.3% (1/3)
--- Конец тестирования ---




Эпоха 81:   0%|          | 0/405 [00:00<?, ?it/s][A
Эпоха 81:   0%|          | 1/405 [00:00<02:20,  2.88it/s][A
Эпоха 81:   0%|          | 2/405 [00:00<02:16,  2.96it/s][A
Эпоха 81:   1%|          | 3/405 [00:01<02:16,  2.95it/s][A
Эпоха 81:   1%|          | 4/405 [00:01<02:10,  3.08it/s][A
Эпоха 81:   1%|          | 5/405 [00:01<02:05,  3.20it/s][A
Эпоха 81:   1%|▏         | 6/405 [00:01<02:02,  3.25it/s][A
Эпоха 81:   2%|▏         | 7/405 [00:02<02:03,  3.23it/s][A
Эпоха 81:   2%|▏         | 8/405 [00:02<02:04,  3.19it/s][A
Эпоха 81:   2%|▏         | 9/405 [00:02<02:06,  3.12it/s][A
Эпоха 81:   2%|▏         | 10/405 [00:03<02:05,  3.16it/s][A
Эпоха 81:   3%|▎         | 11/405 [00:03<02:05,  3.14it/s][A
Эпоха 81:   3%|▎         | 12/405 [00:03<02:05,  3.13it/s][A
Эпоха 81:   3%|▎         | 13/405 [00:04<02:03,  3.17it/s][A
Эпоха 81:   3%|▎         | 14/405 [00:04<02:02,  3.20it/s][A
Эпоха 81:   4%|▎         | 15/405 [00:04<02:02,  3.19it/s][A
Эпоха 81:   4%|▍         


--- Тестирование генерации слов ---
Оригинал: 'extol'
Сгенерировано: 'extol' ✓
----------------------------------------
Оригинал: 'slay'
Сгенерировано: 'slay' ✓
----------------------------------------


Обучение:  90%|█████████ | 90/100 [3:07:26<21:11, 127.12s/it, Loss: 3.0771, TF: 0.55, LR: 0.000015]

Оригинал: 'risen'
Сгенерировано: 'rissahen' ✗
----------------------------------------
Точность: 66.7% (2/3)
--- Конец тестирования ---




Эпоха 91:   0%|          | 0/405 [00:00<?, ?it/s][A
Эпоха 91:   0%|          | 1/405 [00:00<02:11,  3.06it/s][A
Эпоха 91:   0%|          | 2/405 [00:00<02:07,  3.17it/s][A
Эпоха 91:   1%|          | 3/405 [00:00<02:03,  3.26it/s][A
Эпоха 91:   1%|          | 4/405 [00:01<01:58,  3.38it/s][A
Эпоха 91:   1%|          | 5/405 [00:01<01:59,  3.35it/s][A
Эпоха 91:   1%|▏         | 6/405 [00:01<01:57,  3.39it/s][A
Эпоха 91:   2%|▏         | 7/405 [00:02<01:56,  3.42it/s][A
Эпоха 91:   2%|▏         | 8/405 [00:02<01:55,  3.44it/s][A
Эпоха 91:   2%|▏         | 9/405 [00:02<01:55,  3.43it/s][A
Эпоха 91:   2%|▏         | 10/405 [00:02<01:56,  3.40it/s][A
Эпоха 91:   3%|▎         | 11/405 [00:03<01:58,  3.32it/s][A
Эпоха 91:   3%|▎         | 12/405 [00:03<01:59,  3.29it/s][A
Эпоха 91:   3%|▎         | 13/405 [00:03<01:57,  3.32it/s][A
Эпоха 91:   3%|▎         | 14/405 [00:04<01:57,  3.32it/s][A
Эпоха 91:   4%|▎         | 15/405 [00:04<01:56,  3.34it/s][A
Эпоха 91:   4%|▍         


--- Тестирование генерации слов ---
Оригинал: 'ammah'
Сгенерировано: 'ammah' ✓
----------------------------------------


Обучение: 100%|██████████| 100/100 [3:27:36<00:00, 124.56s/it, Loss: 3.0749, TF: 0.50, LR: 0.000000]

Оригинал: 'vehemently'
Сгенерировано: 'vehement' ✗
----------------------------------------
Оригинал: 'throughout'
Сгенерировано: 'throughoutp' ✗
----------------------------------------
Точность: 33.3% (1/3)
--- Конец тестирования ---




