In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math


In [2]:
def get_device():
    return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scaled = scaled.permute(1, 0, 2, 3) + mask
        scaled = scaled.permute(1, 0, 2, 3)
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_sequence_length):
        super().__init__()
        self.max_sequence_length = max_sequence_length
        self.d_model = d_model

    def forward(self):
        even_i = torch.arange(0, self.d_model, 2).float()
        denominator = torch.pow(10000, even_i/self.d_model)
        position = (torch.arange(self.max_sequence_length)
                          .reshape(self.max_sequence_length, 1))
        even_PE = torch.sin(position / denominator)
        odd_PE = torch.cos(position / denominator)
        stacked = torch.stack([even_PE, odd_PE], dim=2)
        PE = torch.flatten(stacked, start_dim=1, end_dim=2)
        return PE

class SentenceEmbedding(nn.Module):
    "For a given sentence, create an embedding"
    def __init__(self, max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN):
        super().__init__()
        self.vocab_size = len(language_to_index)
        self.max_sequence_length = max_sequence_length
        self.embedding = nn.Embedding(self.vocab_size, d_model)
        self.language_to_index = language_to_index
        self.position_encoder = PositionalEncoding(d_model, max_sequence_length)
        self.dropout = nn.Dropout(p=0.1)
        self.START_TOKEN = START_TOKEN
        self.END_TOKEN = END_TOKEN
        self.PADDING_TOKEN = PADDING_TOKEN
    
    def batch_tokenize(self, batch, start_token, end_token):

        def tokenize(sentence, start_token, end_token):
            
            sentence_word_indicies = [self.language_to_index[token] for token in list(sentence)]
            if start_token:
                sentence_word_indicies.insert(0, self.language_to_index[self.START_TOKEN])
            if end_token:
                sentence_word_indicies.append(self.language_to_index[self.END_TOKEN])
            for _ in range(len(sentence_word_indicies), self.max_sequence_length):
                sentence_word_indicies.append(self.language_to_index[self.PADDING_TOKEN])
            return torch.tensor(sentence_word_indicies)

        tokenized = []
        for sentence_num in range(len(batch)):
           tokenized.append( tokenize(batch[sentence_num], start_token, end_token) )
        tokenized = torch.stack(tokenized)
        return tokenized.to(get_device())
    
    def forward(self, x, start_token, end_token): # sentence
        x = self.batch_tokenize(x, start_token, end_token)
        x = self.embedding(x)
        pos = self.position_encoder().to(get_device())
        x = self.dropout(x + pos)
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_layer = nn.Linear(d_model , 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)
    
    def forward(self, x, mask):
        batch_size, sequence_length, d_model = x.size()
        qkv = self.qkv_layer(x)
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)
        values, attention = scaled_dot_product(q, k, v, mask)
        values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, self.num_heads * self.head_dim)
        out = self.linear_layer(values)
        return out


class LayerNormalization(nn.Module):
    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape=parameters_shape
        self.eps=eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta =  nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, inputs):
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = inputs.mean(dim=dims, keepdim=True)
        var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        y = (inputs - mean) / std
        out = self.gamma * y + self.beta
        return out

  
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, hidden, drop_prob=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x


class EncoderLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)

    def forward(self, x, self_attention_mask):
        residual_x = x.clone()
        x = self.attention(x, mask=self_attention_mask)
        x = self.dropout1(x)
        x = self.norm1(x + residual_x)
        residual_x = x.clone()
        x = self.ffn(x)
        x = self.dropout2(x)
        x = self.norm2(x + residual_x)
        return x
    
class SequentialEncoder(nn.Sequential):
    def forward(self, *inputs):
        x, self_attention_mask  = inputs
        for module in self._modules.values():
            x = module(x, self_attention_mask)
        return x

class Encoder(nn.Module):
    def __init__(self, 
                 d_model, 
                 ffn_hidden, 
                 num_heads, 
                 drop_prob, 
                 num_layers,
                 max_sequence_length,
                 language_to_index,
                 START_TOKEN,
                 END_TOKEN, 
                 PADDING_TOKEN):
        super().__init__()
        self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
        self.layers = SequentialEncoder(*[EncoderLayer(d_model, ffn_hidden, num_heads, drop_prob)
                                      for _ in range(num_layers)])

    def forward(self, x, self_attention_mask, start_token, end_token):
        x = self.sentence_embedding(x, start_token, end_token)
        x = self.layers(x, self_attention_mask)
        return x


class MultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.kv_layer = nn.Linear(d_model , 2 * d_model)
        self.q_layer = nn.Linear(d_model , d_model)
        self.linear_layer = nn.Linear(d_model, d_model)
    
    def forward(self, x, y, mask):
        batch_size, sequence_length, d_model = x.size() # in practice, this is the same for both languages...so we can technically combine with normal attention
        kv = self.kv_layer(x)
        q = self.q_layer(y)
        kv = kv.reshape(batch_size, sequence_length, self.num_heads, 2 * self.head_dim)
        q = q.reshape(batch_size, sequence_length, self.num_heads, self.head_dim)
        kv = kv.permute(0, 2, 1, 3)
        q = q.permute(0, 2, 1, 3)
        k, v = kv.chunk(2, dim=-1)
        values, attention = scaled_dot_product(q, k, v, mask) # We don't need the mask for cross attention, removing in outer function!
        values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, d_model)
        out = self.linear_layer(values)
        return out


class DecoderLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.layer_norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)

        self.encoder_decoder_attention = MultiHeadCrossAttention(d_model=d_model, num_heads=num_heads)
        self.layer_norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)

        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.layer_norm3 = LayerNormalization(parameters_shape=[d_model])
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, x, y, self_attention_mask, cross_attention_mask):
        _y = y.clone()
        y = self.self_attention(y, mask=self_attention_mask)
        y = self.dropout1(y)
        y = self.layer_norm1(y + _y)

        _y = y.clone()
        y = self.encoder_decoder_attention(x, y, mask=cross_attention_mask)
        y = self.dropout2(y)
        y = self.layer_norm2(y + _y)

        _y = y.clone()
        y = self.ffn(y)
        y = self.dropout3(y)
        y = self.layer_norm3(y + _y)
        return y


class SequentialDecoder(nn.Sequential):
    def forward(self, *inputs):
        x, y, self_attention_mask, cross_attention_mask = inputs
        for module in self._modules.values():
            y = module(x, y, self_attention_mask, cross_attention_mask)
        return y

class Decoder(nn.Module):
    def __init__(self, 
                 d_model, 
                 ffn_hidden, 
                 num_heads, 
                 drop_prob, 
                 num_layers,
                 max_sequence_length,
                 language_to_index,
                 START_TOKEN,
                 END_TOKEN, 
                 PADDING_TOKEN):
        super().__init__()
        self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
        self.layers = SequentialDecoder(*[DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob) for _ in range(num_layers)])

    def forward(self, x, y, self_attention_mask, cross_attention_mask, start_token, end_token):
        y = self.sentence_embedding(y, start_token, end_token)
        y = self.layers(x, y, self_attention_mask, cross_attention_mask)
        return y


class Transformer(nn.Module):
    def __init__(self, 
                d_model, 
                ffn_hidden, 
                num_heads, 
                drop_prob, 
                num_layers,
                max_sequence_length, 
                jpn_vocab_size,
                eng_to_index,
                jpn_to_index,
                START_TOKEN, 
                END_TOKEN, 
                PADDING_TOKEN
                ):
        super().__init__()
        self.encoder = Encoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers, max_sequence_length, eng_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
        self.decoder = Decoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers, max_sequence_length, jpn_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
        self.linear = nn.Linear(d_model, jpn_vocab_size)
        self.device = torch.device('cpu')

    def forward(self, 
                x, 
                y, 
                encoder_self_attention_mask=None, 
                decoder_self_attention_mask=None, 
                decoder_cross_attention_mask=None,
                enc_start_token=False,
                enc_end_token=False,
                dec_start_token=True, # We should make this true
                dec_end_token=False): # x, y are batch of sentences
        x = self.encoder(x, encoder_self_attention_mask, start_token=enc_start_token, end_token=enc_end_token)
        out = self.decoder(x, y, decoder_self_attention_mask, decoder_cross_attention_mask, start_token=dec_start_token, end_token=dec_end_token)
        out = self.linear(out)
        return out

In [3]:
sentences_eng = []
eng_vocab = []
sentences_jpn = []
jpn_vocab = []
counter = 0

START_TOKEN = '<sos>'
PADDING_TOKEN = '\0'
END_TOKEN = '<eos>'


eng_vocab_set = set(eng_vocab)
jpn_vocab_set = set(jpn_vocab)

counter = 0

with open('jpn.txt', 'r', encoding= 'utf-8') as file:
    for line in file:
        sentences = line.split('\t')
        english_sentence = sentences[0]
        japanese_sentence = sentences[1]

        sentences_jpn.append(japanese_sentence)
        sentences_eng.append(english_sentence)

        for l in english_sentence:
            if l not in eng_vocab_set:
                eng_vocab_set.add(l)
                eng_vocab.append(l)
        
        for k in japanese_sentence:
            if k not in jpn_vocab_set:
                jpn_vocab_set.add(k)
                jpn_vocab.append(k)

eng_vocab.sort()
jpn_vocab.sort()

jpn_vocab = [char for char in jpn_vocab if not (char.isalpha() and char.islower())]
jpn_vocab = [char for char in jpn_vocab if not (char.isalpha() and char.isupper())]
jpn_vocab = [char for char in jpn_vocab if char != '\u3000']

eng_vocab.insert(0, START_TOKEN)
eng_vocab.append(PADDING_TOKEN)
eng_vocab.append(END_TOKEN)

jpn_vocab.insert(0, START_TOKEN)
jpn_vocab.append(PADDING_TOKEN)
jpn_vocab.append(END_TOKEN)

print(jpn_vocab)
print(eng_vocab)

['<sos>', ' ', '!', '"', '%', '&', "'", ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '?', '@', '—', '₂', '℃', '√', '、', '。', '々', '〆', '「', '」', '『', '』', '〜', 'ぁ', 'あ', 'い', 'ぅ', 'う', 'ぇ', 'え', 'ぉ', 'お', 'か', 'が', 'き', 'ぎ', 'く', 'ぐ', 'け', 'げ', 'こ', 'ご', 'さ', 'ざ', 'し', 'じ', 'す', 'ず', 'せ', 'ぜ', 'そ', 'ぞ', 'た', 'だ', 'ち', 'ぢ', 'っ', 'つ', 'づ', 'て', 'で', 'と', 'ど', 'な', 'に', 'ぬ', 'ね', 'の', 'は', 'ば', 'ぱ', 'ひ', 'び', 'ぴ', 'ふ', 'ぶ', 'ぷ', 'へ', 'べ', 'ぺ', 'ほ', 'ぼ', 'ぽ', 'ま', 'み', 'む', 'め', 'も', 'ゃ', 'や', 'ゅ', 'ゆ', 'ょ', 'よ', 'ら', 'り', 'る', 'れ', 'ろ', 'わ', 'ゐ', 'ゑ', 'を', 'ん', '゜', 'ァ', 'ア', 'ィ', 'イ', 'ゥ', 'ウ', 'ェ', 'エ', 'ォ', 'オ', 'カ', 'ガ', 'キ', 'ギ', 'ク', 'グ', 'ケ', 'ゲ', 'コ', 'ゴ', 'サ', 'ザ', 'シ', 'ジ', 'ス', 'ズ', 'セ', 'ゼ', 'ソ', 'ゾ', 'タ', 'ダ', 'チ', 'ッ', 'ツ', 'テ', 'デ', 'ト', 'ド', 'ナ', 'ニ', 'ヌ', 'ネ', 'ノ', 'ハ', 'バ', 'パ', 'ヒ', 'ビ', 'ピ', 'フ', 'ブ', 'プ', 'ヘ', 'ベ', 'ペ', 'ホ', 'ボ', 'ポ', 'マ', 'ミ', 'ム', 'メ', 'モ', 'ャ', 'ヤ', 'ュ', 'ユ', 'ョ', 'ヨ', 'ラ', 'リ', 'ル', 'レ', 'ロ', 'ワ', 'ン', 'ヴ', 'ヵ', 'ヶ', '・', 

In [4]:
index_to_jpn = {k:v for k,v in enumerate(jpn_vocab)}
jpn_to_index = {v:k for k,v in enumerate(jpn_vocab)}
index_to_eng = {k:v for k,v in enumerate(eng_vocab)}
eng_to_index = {v:k for k,v in enumerate(eng_vocab)}

sentences_eng = [sentence.rstrip('\n').lower() for sentence in sentences_eng]
sentences_jpn = [sentence.rstrip('\n') for sentence in sentences_jpn]

PERCENTILE = 97
print(f"{PERCENTILE}th percentile length Japanese: {np.percentile([len(s) for s in sentences_jpn], PERCENTILE)}")
print(f"{PERCENTILE}th percentile length English: {np.percentile([len(s) for s in sentences_eng], PERCENTILE)}")

print(sentences_eng[:10])
print(sentences_jpn[:10])

97th percentile length Japanese: 28.0
97th percentile length English: 60.0
['go.', 'go.', 'hi.', 'hi.', 'hi.', 'hi.', 'run.', 'run.', 'who?', 'wow!']
['行け。', '行きなさい。', 'こんにちは。', 'もしもし。', 'やっほー。', 'こんにちは！', '走れ。', '走って！', '誰？', 'すごい！']


In [5]:
max_sequence_length = 200

def is_valid_tokens(sentence, vocab):
    for token in list(set(sentence)):
        if token not in vocab:
            return False
    return True

def is_valid_length(sentence, max_sequence_length):
    return len(list(sentence)) < (max_sequence_length - 1) # need to re-add the end token so leaving 1 space

valid_sentence_indicies = []
for index in range(len(sentences_jpn)):
    japanese_sentence, english_sentence = sentences_jpn[index], sentences_eng[index]
    if is_valid_length(japanese_sentence, max_sequence_length) \
      and is_valid_length(english_sentence, max_sequence_length) \
      and is_valid_tokens(japanese_sentence, jpn_vocab):
        valid_sentence_indicies.append(index)

print(f"Number of sentences: {len(sentences_jpn)}")
print(f"Number of valid sentences: {len(valid_sentence_indicies)}")

Number of sentences: 104785
Number of valid sentences: 104493


In [6]:
japanese_sentences = [sentences_jpn[i] for i in valid_sentence_indicies]
english_sentences = [sentences_eng[i] for i in valid_sentence_indicies]

print(len(english_sentences))
print(len(japanese_sentences))

104493
104493


In [7]:
d_model = 512
batch_size = 30
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 1
max_sequence_length = 200
jpn_vocab_size = len(jpn_vocab)
transformer = Transformer(d_model, 
                          ffn_hidden,
                          num_heads, 
                          drop_prob, 
                          num_layers, 
                          max_sequence_length,
                          jpn_vocab_size,
                          eng_to_index,
                          jpn_to_index,
                          START_TOKEN, 
                          END_TOKEN, 
                          PADDING_TOKEN)

In [8]:
transformer

Transformer(
  (encoder): Encoder(
    (sentence_embedding): SentenceEmbedding(
      (embedding): Embedding(92, 512)
      (position_encoder): PositionalEncoding()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): SequentialEncoder(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (qkv_layer): Linear(in_features=512, out_features=1536, bias=True)
          (linear_layer): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm1): LayerNormalization()
        (dropout1): Dropout(p=0.1, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (relu): ReLU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm2): LayerNormalization()
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): Decoder(
    (sentence_embedding):

In [9]:
class TextDataset(Dataset):
    def __init__(self, english_sentences, japanese_sentences):
        self.english_sentences = english_sentences
        self.japanese_sentences = japanese_sentences

    def __len__(self):
        return len(self.english_sentences)
    
    def __getitem__(self, index):
        return self.english_sentences[index], self.japanese_sentences[index]
    
dataset = TextDataset(english_sentences, japanese_sentences)

train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)

for batch_num, batch in enumerate(iterator):
    print(batch)
    if(batch_num > 3):
        break



[('go.', 'go.', 'hi.', 'hi.', 'hi.', 'hi.', 'run.', 'run.', 'who?', 'wow!', 'wow!', 'wow!', 'wow!', 'duck!', 'duck!', 'fire!', 'fire!', 'fire!', 'help!', 'help!', 'hide.', 'jump!', 'jump!', 'jump!', 'jump!', 'jump!', 'jump.', 'jump.', 'jump.', 'jump.'), ('行け。', '行きなさい。', 'こんにちは。', 'もしもし。', 'やっほー。', 'こんにちは！', '走れ。', '走って！', '誰？', 'すごい！', 'ワォ！', 'わぉ！', 'おー！', '頭を下げろ！', '伏せて！', '火事だ！', '火事！', '撃て！', '助けて！', '助けてくれ！', '隠れろ。', '飛び越えろ！', '跳べ！', '飛び降りろ！', '飛び跳ねて！', 'ジャンプして！', '跳べ！', '飛び跳ねて！', 'ジャンプして！', '跳んで。')]
[('stop!', 'stop!', 'wait!', 'wait!', 'wait!', 'wait.', 'wait.', 'go on.', 'go on.', 'go on.', 'go on.', 'hello!', 'hello!', 'hello!', 'hello.', 'hello.', 'hurry!', 'i see.', 'i see.', 'i see.', 'i see.', 'i see.', 'i see.', 'i see.', 'i try.', 'i try.', 'i try.', 'i try.', 'i try.', 'i won!'), ('やめろ！', '止まれ！', '待って！', '待ってろよ。', '待ってくれ。', '待ってろよ。', '待ってくれ。', '続けて。', '進んで。', '進め。', '続けろ。', 'こんにちは。', 'もしもし。', 'こんにちは！', 'もしもし。', 'やあ！', '急げ！', 'なるほど。', 'なるほどね。', 'わかった。', 'わかりました。', 'そうですか

In [10]:
criterian = nn.CrossEntropyLoss(ignore_index = jpn_to_index[PADDING_TOKEN],
                                reduction = 'none')
for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

optim = torch.optim.Adam(transformer.parameters(), lr = 1e-4)

device = torch.device('cpu')

In [11]:
NEG_INFINITY = -1e9

def create_masks(eng_batch, jpn_batch):
    num_sentences = len(eng_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length], True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal = 1)
    encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length], False)
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)

    for index in range(num_sentences):
        eng_sentence_length, jpn_sentence_length = len(eng_batch[index]), len(jpn_batch[index])
        eng_chars_to_padding_mask = np.arange(eng_sentence_length + 1, max_sequence_length)
        jpn_chars_to_padding_mask = np.arange(jpn_sentence_length + 1, max_sequence_length)
        encoder_padding_mask[index, :, eng_chars_to_padding_mask] = True
        encoder_padding_mask[index, eng_chars_to_padding_mask, :] = True
        decoder_padding_mask_self_attention[index, :, jpn_chars_to_padding_mask] = True
        decoder_padding_mask_self_attention[index, jpn_chars_to_padding_mask, :] = True
        decoder_padding_mask_cross_attention[index, :, eng_chars_to_padding_mask] = True
        decoder_padding_mask_cross_attention[index, jpn_chars_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFINITY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFINITY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFINITY, 0)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask


In [12]:
transformer.train()
transformer.to(device)
total_loss = 0
num_epochs = 10

for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    iterator = iter(train_loader)
    for batch_num, batch in enumerate(iterator):
        transformer.train()
        eng_batch, jpn_batch = batch
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_batch, jpn_batch)
        optim.zero_grad()
        jpn_predictions = transformer(eng_batch,
                                     jpn_batch,
                                     encoder_self_attention_mask.to(device), 
                                     decoder_self_attention_mask.to(device), 
                                     decoder_cross_attention_mask.to(device),
                                     enc_start_token=False,
                                     enc_end_token=False,
                                     dec_start_token=True,
                                     dec_end_token=True)
        
        labels = transformer.decoder.sentence_embedding.batch_tokenize(jpn_batch, start_token=False, end_token=True)
        loss = criterian(
            jpn_predictions.view(-1, jpn_vocab_size).to(device),
            labels.view(-1).to(device)
        ).to(device)
        valid_indicies = torch.where(labels.view(-1) == jpn_to_index[PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        loss.backward()
        optim.step()
        #train_losses.append(loss.item())
        if batch_num % 100 == 0:
            print(f"Iteration {batch_num} : {loss.item()}")
            print(f"English: {eng_batch[0]}")
            print(f"Japanese Translation: {jpn_batch[0]}")
            jpn_sentence_predicted = torch.argmax(jpn_predictions[0], axis=1)
            predicted_sentence = ""
            for index in jpn_sentence_predicted:
              if index == jpn_to_index[END_TOKEN]:
                break
              predicted_sentence += index_to_jpn[index.item()]
            print(f"Japanese Prediction: {predicted_sentence}")


            transformer.eval()
            jpn_sentence = ("",)
            eng_sentence = ("should we go to the mall?",)
            for word_counter in range(max_sequence_length):
                encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, jpn_sentence)
                predictions = transformer(eng_sentence,
                                          jpn_sentence,
                                          encoder_self_attention_mask.to(device), 
                                          decoder_self_attention_mask.to(device), 
                                          decoder_cross_attention_mask.to(device),
                                          enc_start_token=False,
                                          enc_end_token=False,
                                          dec_start_token=True,
                                          dec_end_token=False)
                next_token_prob_distribution = predictions[0][word_counter] # not actual probs
                next_token_index = torch.argmax(next_token_prob_distribution).item()
                next_token = index_to_jpn[next_token_index]
                jpn_sentence = (jpn_sentence[0] + next_token, )
                if next_token == END_TOKEN:
                  break
            
            print(f"Evaluation translation (should we go to the mall?) : {jpn_sentence}")
            print("-------------------------------------------")

Epoch 0
Iteration 0 : 8.143197059631348
English: go.
Japanese Translation: 行け。
Japanese Prediction: 然価然価停停然停儲儲停鉱痛痛処函処廊函函儲函儲函処処処処函函函函函函函函函函頑函或偶念念察念念念念函儲頑儲涙停儲林林函林林頑儲函儲或函函函函函函函函函函函易函易函函函易函函函躾函函函易偶携函函函或函函函函函函函函函函函函函函函函函函函偶函函函函虎処函易函虎偶偶函函函函函易函函函函易函函函函偶函偶函函停函函活案曜縄函函函函函函函縄儲澄処泣縄函函縄処函函函函函憤印函躍函函躍函覇議函乱議携板函函函函虎函
Evaluation translation (should we go to the mall?) : ('詫然詫然然然然然外外！退退詫然然然安安安運然然呟呟処処処処処処争争争函！！！！呟呟呟呟呟呟！！！呟呟呟頑頑頑携呟呟呟呟呟携携函呟呟呟呟呟函函函函詫詫詫函鶴函函呟呟函函函函函函函函函呟呟携携携鋼鋼鋼函呟呟呟函函函函函函函函函函函函函函函函鞄鞄鞄虎虎函函函容呟呟呟呟函函函函函函函函函函呟呟呟呟呟呟函函函函呟呟活活活活函函函函直直直直直容容容容容容退退活活活活活活活詫詫詫躍退函函函呟呟呟呟呟函函函携詫詫',)
-------------------------------------------


KeyboardInterrupt: 