In [57]:
import torch
import numpy as np
from torch import nn

In [58]:
sorted_file = 'train.mn'
unsorted_file = 'train.mgl'

START_TOKEN = ''
PADDING_TOKEN = ''
END_TOKEN = ''

unsorted_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '₮', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', 
                      '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', '<', '=', '>', '?', 'ˌ', 
                      'а', 'б', 'в', 'г', 'д', 'е', 'ё', 'ж', 
                      'з', 'и', 'й', 'к', 'л', 'м', 'н', 'о', 'ө', 'п', 'р', 'с', 'т', 'у', 'ү', 
                      'ф', 'х', 'ц', 'ч', 'ш', 
                      'щ', 'ъ', 'ь', 'ы', 'э', 
                      'ю', 'я', PADDING_TOKEN, END_TOKEN]

sorted_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '₮', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', 
                      '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', '<', '=', '>', '?', 'ˌ', 
                      'а', 'б', 'в', 'г', 'д', 'е', 'ё', 'ж', 
                      'з', 'и', 'й', 'к', 'л', 'м', 'н', 'о', 'ө', 'п', 'р', 'с', 'т', 'у', 'ү', 
                      'ф', 'х', 'ц', 'ч', 'ш', 
                      'щ', 'ъ', 'ь', 'ы', 'э', 
                      'ю', 'я', PADDING_TOKEN, END_TOKEN]

In [59]:
index_to_unsorted = {k:v for k,v in enumerate(unsorted_vocabulary)}
unsorted_to_index = {v:k for k,v in enumerate(unsorted_vocabulary)}
index_to_sorted = {k:v for k,v in enumerate(sorted_vocabulary)}
sorted_to_index = {v:k for k,v in enumerate(sorted_vocabulary)}

In [60]:
with open(sorted_file, 'r') as file:
    sorted_sentences = file.readlines()
with open(unsorted_file, 'r') as file:
    unsorted_sentences = file.readlines()

TOTAL_SENTENCES = 3
sorted_sentences = sorted_sentences[:TOTAL_SENTENCES]
unsorted_sentences = unsorted_sentences[:TOTAL_SENTENCES]
sorted_sentences = [sentence.rstrip('\n') for sentence in sorted_sentences]
unsorted_sentences = [sentence.rstrip('\n') for sentence in unsorted_sentences]

In [61]:
sorted_sentences[:3]


['байна уу', 'байна', 'миний нэрийг эрхэмхүү гэдэг']

In [62]:
unsorted_sentences[:3]

['бну', 'бн', 'минии нэрииг эрхэмхүү гэдэг']

In [63]:
max(len(x) for x in unsorted_sentences), max(len(x) for x in sorted_sentences),

(27, 27)

In [64]:
PERCENTILE = 97
print( f"{PERCENTILE}th percentile length unsorted: {np.percentile([len(x) for x in unsorted_sentences], PERCENTILE)}" )
print( f"{PERCENTILE}th percentile length sorted: {np.percentile([len(x) for x in sorted_sentences], PERCENTILE)}" )

97th percentile length unsorted: 25.56
97th percentile length sorted: 25.86


In [65]:
max_sequence_length = 200

def is_valid_tokens(sentence, vocab):
    for token in list(set(sentence)):
        if token not in vocab:
            return False
    return True

def is_valid_length(sentence, max_sequence_length):
    return len(list(sentence)) < (max_sequence_length - 1) # need to re-add the end token so leaving 1 space

valid_sentence_indicies = []
for index in range(len(unsorted_sentences)):
    unsorted_sentence, sorted_sentence = unsorted_sentences[index], sorted_sentences[index]
    if is_valid_length(unsorted_sentence, max_sequence_length) \
      and is_valid_length(sorted_sentence, max_sequence_length) \
      and is_valid_tokens(unsorted_sentence, unsorted_vocabulary):
        valid_sentence_indicies.append(index)

print(f"Number of sentences: {len(unsorted_sentences)}")
print(f"Number of valid sentences: {len(valid_sentence_indicies)}")

Number of sentences: 3
Number of valid sentences: 3


In [66]:
unsorted_sentences = [unsorted_sentences[i] for i in valid_sentence_indicies]
sorted_sentences = [sorted_sentences[i] for i in valid_sentence_indicies]

In [67]:
unsorted_sentences[:3]

['бну', 'бн', 'минии нэрииг эрхэмхүү гэдэг']

In [68]:
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):

    def __init__(self, sorted_sentences, unsorted_sentences):
        self.sorted_sentences = sorted_sentences
        self.unsorted_sentences = unsorted_sentences

    def __len__(self):
        return len(self.sorted_sentences)

    def __getitem__(self, idx):
        return self.sorted_sentences[idx], self.unsorted_sentences[idx]

In [69]:
dataset = TextDataset(sorted_sentences, unsorted_sentences)

In [70]:
len(dataset)

3

In [71]:
dataset[1]

('байна', 'бн')

In [72]:
batch_size = 3 
train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)

In [73]:
for batch_num, batch in enumerate(iterator):
    print(batch)
    if batch_num > 3:
        break


[('байна уу', 'байна', 'миний нэрийг эрхэмхүү гэдэг'), ('бну', 'бн', 'минии нэрииг эрхэмхүү гэдэг')]


In [74]:
def tokenize(sentence, language_to_index, start_token=True, end_token=True):
    sentence_word_indicies = [language_to_index[token] for token in list(sentence)]
    if start_token:
        sentence_word_indicies.insert(0, language_to_index[START_TOKEN])
    if end_token:
        sentence_word_indicies.append(language_to_index[END_TOKEN])
    for _ in range(len(sentence_word_indicies), max_sequence_length):
        sentence_word_indicies.append(language_to_index[PADDING_TOKEN])
    return torch.tensor(sentence_word_indicies)
     

In [75]:
batch

[('байна уу', 'байна', 'миний нэрийг эрхэмхүү гэдэг'),
 ('бну', 'бн', 'минии нэрииг эрхэмхүү гэдэг')]

In [76]:
sorted_tokenized, unsorted_tokenized = [], []
for sentence_num in range(batch_size):
    sorted_sentence, unsorted_sentence = batch[0][sentence_num], batch[1][sentence_num]
    sorted_tokenized.append( tokenize(sorted_sentence, sorted_to_index, start_token=False, end_token=False) )
    unsorted_tokenized.append( tokenize(unsorted_sentence, unsorted_to_index, start_token=True, end_token=True) )
sorted_tokenized = torch.stack(sorted_tokenized)
unsorted_tokenized = torch.stack(unsorted_tokenized)

In [77]:
sorted_tokenized

tensor([[34, 33, 43, 47, 33,  1, 54, 54, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69,
         69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69,
         69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69,
         69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69,
         69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69,
         69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69,
         69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69,
         69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69,
         69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69,
         69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69,
         69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69,
         69, 69],
        [34, 33, 43, 47, 33, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69,
         6

In [78]:
NEG_INFTY = -1e9

def create_masks(sorted_batch, unsorted_batch):
    num_sentences = len(sorted_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)

    for idx in range(num_sentences):
      sorted_sentence_length, unsorted_sentence_length = len(sorted_batch[idx]), len(unsorted_batch[idx])
      sorted_chars_to_padding_mask = np.arange(sorted_sentence_length + 1, max_sequence_length)
      unsorted_chars_to_padding_mask = np.arange(unsorted_sentence_length + 1, max_sequence_length)
      encoder_padding_mask[idx, :, sorted_chars_to_padding_mask] = True
      encoder_padding_mask[idx, sorted_chars_to_padding_mask, :] = True
      decoder_padding_mask_self_attention[idx, :, unsorted_chars_to_padding_mask] = True
      decoder_padding_mask_self_attention[idx, unsorted_chars_to_padding_mask, :] = True
      decoder_padding_mask_cross_attention[idx, :, sorted_chars_to_padding_mask] = True
      decoder_padding_mask_cross_attention[idx, unsorted_chars_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
    print(f"encoder_self_attention_mask {encoder_self_attention_mask.size()}: {encoder_self_attention_mask[0, :10, :10]}")
    print(f"decoder_self_attention_mask {decoder_self_attention_mask.size()}: {decoder_self_attention_mask[0, :10, :10]}")
    print(f"decoder_cross_attention_mask {decoder_cross_attention_mask.size()}: {decoder_cross_attention_mask[0, :10, :10]}")
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

In [79]:
create_masks(batch[0], batch[1])

encoder_self_attention_mask torch.Size([3, 200, 200]): tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+09],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+09],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+09],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+09],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+09],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+09],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.

(tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          ...,
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09]],
 
         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          ...,
    

In [80]:
class SentenceEmbedding(nn.Module):
    def __init__(self, max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN):
        super().__init__()
        self.vocab_size = len(language_to_index)
        self.max_sequence_length = max_sequence_length
        self.embedding = nn.Embedding(self.vocab_size, d_model)
        self.language_to_index = language_to_index
        self.position_encoder = PositionalEncoding(d_model, max_sequence_length)
        self.dropout = nn.Dropout(p=0.1)
        self.START_TOKEN = START_TOKEN
        self.END_TOKEN = END_TOKEN
        self.PADDING_TOKEN = PADDING_TOKEN
    
    def batch_tokenize(self, batch, start_token=True, end_token=True):

        def tokenize(sentence, start_token=True, end_token=True):
            sentence_word_indicies = [self.language_to_index[token] for token in list(sentence)]
            if start_token:
                sentence_word_indicies.insert(0, self.language_to_index[self.START_TOKEN])
            if end_token:
                sentence_word_indicies.append(self.language_to_index[self.END_TOKEN])
            for _ in range(len(sentence_word_indicies), self.max_sequence_length):
                sentence_word_indicies.append(self.language_to_index[self.PADDING_TOKEN])
            return torch.tensor(sentence_word_indicies)

        tokenized = []
        for sentence_num in range(len(batch)):
           tokenized.append( tokenize(batch[sentence_num], start_token, end_token) )
        tokenized = torch.stack(tokenized)
        return tokenized.to(get_device())
    
    def forward(self, x, end_token=True): # sentence
        x = self.batch_tokenize(x ,end_token)
        x = self.embedding(x)
        pos = self.position_encoder().to(get_device())
        x = self.dropout(x + pos)
        return x