In [1]:
from tokenizer import train_mopiece, MOPiece
from word_autoencoder import WordEncoder, WordDecoder
import torch as pt
from torch import nn
import polars as pl
import regex as re
from tqdm.notebook import trange, tqdm

# Pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
if pt.cuda.is_available():
    device = 'cuda'
elif pt.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

In [2]:
# Affix inventories: one column is taken from each MorphyNet CSV.
suffix_table = pl.read_csv('data/morphynet/suffixes.csv')
suffixes = suffix_table['suffix']
prefix_table = pl.read_csv('data/morphynet/prefixes.csv')
prefixes = prefix_table['prefix']

# Train the tokenizer on the corpus and write it under the '____tokenizer' name.
# The positional 3000 is presumably the target vocabulary size -- confirm against
# `tokenizer.train_mopiece`.
train_mopiece(
    '____tokenizer',
    ['data/text/bible.txt'],
    prefixes,
    suffixes,
    3000,
    spm_model_type='bpe',
)

# Load the tokenizer that was just trained; later cells use it as `mopiece`.
mopiece = MOPiece('____tokenizer')

In [3]:
# Build the set of distinct, lower-cased tokens that occur in the corpus.
#
# The pattern splits on (a) runs of characters that are neither letters, marks,
# digits nor whitespace, and (b) single whitespace characters. Because the
# pattern is a capturing group, `split` also returns the separators themselves,
# so punctuation runs are kept as tokens of their own.
words = set()
reg = re.compile(r'([^\p{L}\p{M}\p{N}\s]+|\s)')
with open('data/text/bible.txt', mode='r', encoding='utf8') as file:
    # Iterate the file lazily instead of loading every line with readlines().
    for line in file:
        for word in reg.split(line):
            # Drop empty pieces and *all* whitespace separators. Comparing only
            # against ' ' let the '\n' that ends every line (and tabs etc.) into
            # the vocabulary as a "word".
            if not word or word.isspace():
                continue
            words.add(word.lower())

In [4]:
class WordDataset(pt.utils.data.Dataset):
    """Pre-tokenised words, stored as triples of id tensors.

    Every word is passed through ``mopiece.encode_word`` once, at construction
    time. Each of the three id lists it returns (prefix, spm, suffix) is wrapped
    in ``bos_id`` / ``eos_id`` and stored as a 1-D ``long`` tensor on ``device``.

    Args:
        words: iterable of strings to encode.
        mopiece: tokenizer exposing ``bos_id()``, ``eos_id()`` and
            ``encode_word(word)`` returning three lists of ids.
        device: device passed to ``pt.tensor`` (``None`` keeps torch's default).
    """

    def __init__(self, words, mopiece, device=None):
        super().__init__()
        start, stop = mopiece.bos_id(), mopiece.eos_id()

        def framed(ids):
            # BOS + ids + EOS as a long tensor on the requested device.
            return pt.tensor([start] + ids + [stop], dtype=pt.long, device=device)

        self.words = []
        for token in words:
            prefix_part, spm_part, suffix_part = mopiece.encode_word(token)
            triple = (framed(prefix_part), framed(spm_part), framed(suffix_part))
            self.words.append(triple)

    def __len__(self):
        """Number of encoded words."""
        return len(self.words)

    def __getitem__(self, index):
        """Return the ``(prefix, spm, suffix)`` tensor triple stored at ``index``."""
        return self.words[index]
    
class SequenceCrossEntropyLoss(nn.CrossEntropyLoss):
    """Per-token cross-entropy over a word's three id streams, averaged over the batch.

    For every word in the batch the token losses of the prefix, spm and suffix
    streams are summed and divided by the number of target tokens that are not
    ``ignore_index``; the result is the mean of these per-word values.

    Constructor arguments are forwarded to ``nn.CrossEntropyLoss`` (``weight``,
    ``ignore_index``, ``label_smoothing`` are honoured). ``reduction`` is fixed
    to ``'sum'`` and must not be passed by the caller.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, reduction='sum')

    def forward(self, prefix_logits, spm_logits, suffix_logits, prefix_ids, spm_ids, suffix_ids):
        """Compute the batch-mean of the per-word, per-token loss.

        Args:
            prefix_logits, spm_logits, suffix_logits: tensors of shape
                (batch, seq, vocab); ``seq`` may differ between streams.
            prefix_ids, spm_ids, suffix_ids: target ids of shape (batch, seq),
                matching the corresponding logits.

        Returns:
            A scalar tensor.
        """
        nll_per_word = 0.
        tokens_per_word = 0
        streams = ((prefix_logits, prefix_ids), (spm_logits, spm_ids), (suffix_logits, suffix_ids))
        for logits, ids in streams:
            # One batched call per stream instead of one call per sequence.
            # cross_entropy wants the class dimension second: (batch, vocab, seq).
            per_token = nn.functional.cross_entropy(
                logits.transpose(1, 2),
                ids,
                weight=self.weight,
                ignore_index=self.ignore_index,
                reduction='none',
                label_smoothing=self.label_smoothing,
            )
            nll_per_word = nll_per_word + per_token.sum(dim=1)
            # Count only real targets. Ignored (padding) positions add nothing to
            # the summed loss, so counting them in the normaliser would make the
            # result depend on how much padding the batch happens to contain.
            tokens_per_word = tokens_per_word + (ids != self.ignore_index).sum(dim=1)
        # clamp guards against a word whose targets are all ignored (0 / 0).
        return (nll_per_word / tokens_per_word.clamp(min=1)).mean()
    
# Encode every collected word once, up front; WordDataset creates the id tensors
# directly on `device`, so batches need no further host-to-device copy.
dset = WordDataset(words, mopiece, device=device)

def collate_fn(batch):
    """Right-pad a list of ``(prefix, spm, suffix)`` tensor triples into three batch tensors.

    Each stream is padded independently with ``mopiece.pad_id()`` to the length
    of its longest member, giving three batch-first tensors
    ``(prefix_batch, spm_batch, suffix_batch)``.
    """
    pad_value = mopiece.pad_id()

    def pad_stream(position):
        # Gather the `position`-th tensor of every triple and pad them together.
        return pt.nn.utils.rnn.pad_sequence(
            [triple[position] for triple in batch],
            batch_first=True,
            padding_value=pad_value,
            padding_side='right',
        )

    return pad_stream(0), pad_stream(1), pad_stream(2)

In [5]:
# Both models share the tokenizer's vocabulary size and padding id, and receive the
# same positional value 128 (presumably the embedding width -- confirm against
# `word_autoencoder`).
vocab_size = mopiece.vocab_size()
pad_id = mopiece.pad_id()
dim = 128

encoder = WordEncoder(
    vocab_size,
    pad_id,
    dim,
    ffn_hidden_dim=256,
    expansion_factor=4,
).to(device)
decoder = WordDecoder(
    vocab_size,
    pad_id,
    mopiece.bos_id(),
    mopiece.eos_id(),
    dim,
    num_layers=6,
).to(device)

In [6]:
# Padding targets are excluded from the loss via ignore_index.
criterion = SequenceCrossEntropyLoss(ignore_index=mopiece.pad_id())

epochs = 30
batch_size = 128

# drop_last=True keeps every batch at exactly `batch_size` words.
loader = pt.utils.data.DataLoader(dset, batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn)

# Encoder and decoder are optimised jointly; the cosine schedule spans the whole run
# (T_max == epochs) and is advanced once per epoch below.
optim = pt.optim.AdamW([{"params": encoder.parameters()}, {"params": decoder.parameters()}])
lr_scheduler = pt.optim.lr_scheduler.CosineAnnealingLR(optim, epochs)

encoder.train()
decoder.train()

# Mean criterion value per epoch.
log_perplexities = []
for epoch in (pbar := trange(epochs, desc='Epoch')):
    log_perplexity_sum = 0.
    for prefix_ids, spm_ids, suffix_ids in tqdm(loader, desc='Training', leave=False):
        optim.zero_grad()

        embedding = encoder(prefix_ids, spm_ids, suffix_ids)

        # Teacher forcing: the decoder sees each stream without its last position
        # and is scored against the same stream shifted left by one.
        prefix_logits, spm_logits, suffix_logits = decoder(prefix_ids[..., :-1], spm_ids[..., :-1], suffix_ids[..., :-1], embedding)

        loss = criterion(prefix_logits, spm_logits, suffix_logits, prefix_ids[..., 1:], spm_ids[..., 1:], suffix_ids[..., 1:])

        loss.backward()
        optim.step()

        log_perplexity_sum += loss.item()

    # Without this call the scheduler is never advanced and the learning rate
    # stays at its initial value for the entire run.
    lr_scheduler.step()

    log_perplexity = log_perplexity_sum / len(loader)
    log_perplexities.append(log_perplexity)
    pbar.set_postfix_str(f'log-perplexity: {log_perplexity:.2f}')
        


Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

Training:   0%|          | 0/101 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [7]:
from random import sample

# Round-trip one randomly chosen vocabulary word through the autoencoder and
# print the input ids next to the decoded output.
with pt.no_grad():
    encoder.eval()
    decoder.eval()
    bos_id, eos_id = mopiece.bos_id(), mopiece.eos_id()

    word = sample(list(words), 1)[0]
    # Wrap each of the three id lists in BOS/EOS, the same framing the dataset uses.
    prefix_ids, spm_ids, suffix_ids = (
        [bos_id] + piece_ids + [eos_id] for piece_ids in mopiece.encode_word(word)
    )
    print('input:')
    print(prefix_ids, spm_ids, suffix_ids)
    print(word)

    # NOTE(review): these are 1-D (unbatched) tensors, whereas training feeds
    # (batch, seq) tensors -- confirm the encoder handles both.
    id_tensors = [
        pt.tensor(ids, dtype=pt.long, device=device)
        for ids in (prefix_ids, spm_ids, suffix_ids)
    ]
    embedding = encoder(*id_tensors)

    decoded = decoder.inference(embedding)
    out_prefix_ids, out_spm_ids, out_suffix_ids = decoded
    print('\noutput:')
    print(out_prefix_ids, out_spm_ids, out_suffix_ids)
    print(mopiece.decode_word(out_prefix_ids, out_spm_ids, out_suffix_ids))

input:
[0, 1] [0, 2108, 1] [0, 137, 1]
thunders
torch.bool
torch.bool

output:
[0, 602, 1] [0, 1] [0, 387, 1]
pros
