In [3]:
import urllib
import random
import os
import collections
import math
import numpy as np
from tqdm import tqdm
import torch
from torch import nn
from torch.nn import functional as F
import torchtext
from torch.utils.data import Dataset, DataLoader, random_split

In [4]:
os.makedirs('data', exist_ok=True)
dataset_train, dataset_test = torchtext.datasets.Multi30k(root='./data', split=('train', 'test'), language_pair=('de', 'en'))
tokenizer_en = torchtext.data.utils.get_tokenizer('spacy', language="en_core_web_sm")
tokenizer_de = torchtext.data.utils.get_tokenizer('spacy', language="de_core_news_sm")
dataset_train = list(dataset_train)
dataset_test = list(dataset_test)

In [5]:
tokenized_en_data = []
tokenized_de_data = []
counter_en = collections.Counter()
counter_de = collections.Counter()

print('Start Tokenizing...')
for de, en in tqdm(dataset_train):
    tokens_de = tokenizer_en(de)
    tokens_en = tokenizer_de(en)
    tokenized_de_data.append(tokens_de)
    tokenized_en_data.append(tokens_en)
    
print('Making En Vocab')
for line in tqdm(tokenized_en_data):
    counter_en.update(line)

print('Making De Vocab')
for line in tqdm(tokenized_de_data):
    counter_de.update(line)

vocab_en = torchtext.vocab.vocab(counter_en, min_freq=1, specials=["<unk>", "<pad>", "<sos>", "<eos>"])
vocab_de = torchtext.vocab.vocab(counter_de, min_freq=1, specials=["<unk>", "<pad>", "<sos>", "<eos>"])




Start Tokenizing...


100%|██████████| 29001/29001 [00:07<00:00, 4038.04it/s]


Making En Vocab


100%|██████████| 29001/29001 [00:00<00:00, 162961.48it/s]


Making De Vocab


100%|██████████| 29001/29001 [00:00<00:00, 174037.78it/s]


In [6]:
class Multi30kDataset(Dataset):
    def __init__(self, tokenized_src, tokenized_tgt, vocab_src, vocab_tgt, max_seq = 256):
        self.src = []
        self.tgt = []

        for tokens in tqdm(tokenized_src, "Src data"):
            token_ids = [vocab_src['<sos>']]
            token_ids += [vocab_src[token] for token in tokens]
            token_ids += [vocab_src['<eos>']]
            self.src.append(token_ids)

        for tokens in tqdm(tokenized_tgt, "Tgt data"):
            token_ids = [vocab_tgt['<sos>']]
            token_ids += [vocab_tgt[token] for token in tokens]
            token_ids += [vocab_tgt['<eos>']]
            self.tgt.append(token_ids)

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        return torch.tensor(self.src[idx]), torch.tensor(self.tgt[idx])

In [7]:
multi30k_dataset = Multi30kDataset(tokenized_de_data, tokenized_en_data, vocab_de, vocab_en)

Src data: 100%|██████████| 29001/29001 [00:00<00:00, 48766.91it/s]
Tgt data: 100%|██████████| 29001/29001 [00:00<00:00, 54135.17it/s]


In [8]:
train_ratio = 0.8
train_size = int(train_ratio*len(multi30k_dataset))
valid_size = len(multi30k_dataset) - train_size
train_dataset, valid_dataset = random_split(multi30k_dataset, [train_size, valid_size])

In [80]:
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hidden_size, n_heads, device):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.head_size = hidden_size // n_heads

        self.fc_q = nn.Linear(hidden_size, hidden_size)
        self.fc_k = nn.Linear(hidden_size, hidden_size)
        self.fc_v = nn.Linear(hidden_size, hidden_size)
        self.fc_o = nn.Linear(hidden_size, hidden_size)

        self.scale = torch.sqrt(torch.FloatTensor([self.head_size])).to(device)

    def forward(self, query, key, value, mask=None):
        #query: [batch_size, query_length, hidden_size]
        #key: [batch_size, key_length, hidden_size]
        #value: [batch_size, value_length, hidden_size]

        batch_size = query.size(0)

        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)

        #Q: [batch_size, query_length, hidden_size]
        #K: [batch_size, key_length, hidden_size]
        #V: [batch_size, value_length, hidden_size]

        Q = Q.view(batch_size, -1, self.n_heads, self.head_size).permute(0,2,1,3)
        K = K.view(batch_size, -1, self.n_heads, self.head_size).permute(0,2,1,3)
        V = V.view(batch_size, -1, self.n_heads, self.head_size).permute(0,2,1,3)
        #Q: [batch_size, n_heads, query_length, head_size]
        #K: [batch_size, n_heads, key_length, head_size]
        #V: [batch_size, n_heads, value_length, head_size]

        energy = torch.matmul(Q, K.permute(0,1,3,2)) / self.scale
        #energy: [batch_size, n_heads, query_length, key_length]

        if mask is not None:
            energy = energy.masked_fill(mask==0, -1e10)
        attention = torch.softmax(energy, dim=-1)
        #attention: [batch_size, n_heads, query_length, key_length]

        out = torch.matmul(attention, V)
        #out: [batch_size, n_heads, query_length, head_size]

        out = out.permute(0,2,1,3).contiguous()
        #out: [batch_size, query_length, n_heads, head_size]

        out = out.view(batch_size, -1, self.hidden_size)
        #out: [batch_size, query_length, hidden_size]

        out = self.fc_o(out)
        #out: [batch_size, query_length, hidden_size]

        return out, attention

In [81]:
class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self, hidden_size, pwff_size):
        super().__init__()

        self.fc1 = nn.Linear(hidden_size, pwff_size)
        self.fc2 = nn.Linear(pwff_size, hidden_size)

    def forward(self, input):
        #input: [batch_size, sequence_length, hidden_size]

        out = self.fc1(input)
        out = torch.relu(out)
        out = self.fc2(out)

        #out: [batch_size, sequence_length, hidden_size]

        return out


In [82]:
class EncoderLayer(nn.Module):
    def __init__(self, hidden_size, n_heads, pwff_size, device):
        super().__init__()

        self.self_attn_layer_norm = nn.LayerNorm(hidden_size)
        self.ff_layer_norm = nn.LayerNorm(hidden_size)
        self.self_attention = MultiHeadAttentionLayer(hidden_size, n_heads, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hidden_size, pwff_size)
    
    def forward(self, input, input_mask):
        #input: [batch_size, input_size, hidden_size]
        #input_mask: [batch_size, 1, 1, input_size]

        _out, _ = self.self_attention(input, input, input, input_mask)

        out = self.self_attn_layer_norm(input + _out)

        #input: [batch_size, input_size, hidden_size]

        _out = self.positionwise_feedforward(out)

        out = self.ff_layer_norm(out + _out)

        #out: [batch_size, input_size, hidden_size]

        return out


In [83]:
#Ref: https://github.com/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb
class Encoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, pwff_size, n_layers, n_heads, device, max_size=100):
        super().__init__()

        self.device = device
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        self.position_embedding = nn.Embedding(max_size, hidden_size)
        self.layers = nn.ModuleList([EncoderLayer(hidden_size, n_heads, pwff_size,device) for _ in range(n_layers)])
        self.scale = torch.sqrt(torch.FloatTensor([hidden_size])).to(device)
    def forward(self, input, input_mask):
        #input: [batch_size, input_size]
        #input_mask: [batch_size, 1, 1, input_size]

        batch_size, input_size = input.size()

        pos = torch.arange(0, input_size).unsqueeze(0).repeat(batch_size,1).to(self.device)
        #pos: [batch_size, input_size]

        out = (self.token_embedding(input) * self.scale) + self.position_embedding(pos)
        #out: [batch_size, input_size, hidden_size]

        for layer in self.layers:
            out = layer(out, input_mask)
        
        #out: [batch_size, input_size, hidden_size]

        return out


In [84]:
class DecoderLayer(nn.Module):
    def __init__(self, hidden_size, n_heads, pwff_size, device):
        super().__init__()

        self.self_attn_layer_norm = nn.LayerNorm(hidden_size)
        self.enc_dec_attn_layer_norm = nn.LayerNorm(hidden_size)
        self.ff_layer_norm = nn.LayerNorm(hidden_size)
        self.self_attention = MultiHeadAttentionLayer(hidden_size, n_heads, device)
        self.encoder_decoder_attention = MultiHeadAttentionLayer(hidden_size, n_heads, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hidden_size, pwff_size)
    
    def forward(self, input, encoder_output, input_mask, encoder_mask):
        #input: [batch_size, input_size, hidden_size]
        #encoder_output = [batch_size, encoder_input_size, hidden_size]
        #input_mask: [batch_size, 1, input_size, input_size]
        #encoder_mask: [batch_size, 1, 1, encoder_input_size]

        _out, _ = self.self_attention(input, input, input, input_mask)

        out = self.self_attn_layer_norm(input + _out)

        _out, attention = self.encoder_decoder_attention(input, encoder_output, encoder_output, encoder_mask)

        out = self.enc_dec_attn_layer_norm(out + _out)

        _out = self.positionwise_feedforward(out)

        out = self.ff_layer_norm(out + _out)


        return out, attention


In [85]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, pwff_size, n_layers, n_heads, device, max_size=100):
        super().__init__()

        self.device = device
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        self.position_embedding = nn.Embedding(max_size, hidden_size)
        self.layers = nn.ModuleList([DecoderLayer(hidden_size, n_heads, pwff_size,device) for _ in range(n_layers)])
        self.fc_out = nn.Linear(hidden_size, vocab_size)
        self.scale = torch.sqrt(torch.FloatTensor([hidden_size])).to(device)
    def forward(self, input, encoder_output, input_mask, encoder_mask):
        #input: [batch_size, input_size, hidden_size]
        #encoder_output = [batch_size, encoder_input_size, hidden_size]
        #input_mask: [batch_size, 1, input_size, input_size]
        #encoder_mask: [batch_size, 1, 1, encoder_input_size]

        batch_size, input_size = input.size()

        pos = torch.arange(0, input_size).unsqueeze(0).repeat(batch_size,1).to(self.device)
        #pos: [batch_size, input_size]

        out = (self.token_embedding(input) * self.scale) + self.position_embedding(pos)
        #out: [batch_size, input_size, hidden_size]

        for layer in self.layers:
            out, attention = layer(out, encoder_output, input_mask, encoder_mask)
        
        #out: [batch_size, input_size, hidden_size]
        #attention: [batch_size, n_heads, input_size, encoder_input_size]

        out = self.fc_out(out)
        #out: [batch_size, input_size, vocab_size]

        return out, attention

In [86]:
class Transformer(nn.Module):
    def __init__(self, encoder, decoder, enc_pad_idx, dec_pad_idx, device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.enc_pad_idx = enc_pad_idx
        self.dec_pad_idx = dec_pad_idx
        self.device = device
    def make_enc_mask(self, input):
        #input: [batch_size, input_size]
        enc_mask = (input != self.enc_pad_idx).unsqueeze(1).unsqueeze(2)
        #enc_mask: [batch_size, 1, 1, input_size]

    def make_dec_mask(self, input):
        #input: [batch_size, input_size]
        dec_pad_mask = (input != self.dec_pad_idx).unsqueeze(1).unsqueeze(2)
        #dec_pad_mask: [batch_size, 1, 1, input_size]

        dec_input_size = input.size(1)
        dec_sub_mask = torch.tril(torch.ones((dec_input_size,dec_input_size), device=self.device)).bool()
        dec_mask = dec_pad_mask & dec_sub_mask

        return dec_mask

    def forward(self, src, tgt):
        #src: [batch_size, src_length]
        #tgt: [batch_size, tgt_length]

        src_mask = self.make_enc_mask(src)
        tgt_mask = self.make_dec_mask(tgt)

        enc_output = self.encoder(src, src_mask)

        output, attention = self.decoder(tgt, enc_output, tgt_mask, src_mask)

        return output, attention



In [87]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [88]:
batch_size=256
learning_rate = 0.001
num_epochs = 10

In [89]:
def pad_collate_fn(batch):
    collate_x = []
    collate_y = []
    for src, tgt in batch:
        collate_x.append(src)
        collate_y.append(tgt)
    collate_x = nn.utils.rnn.pad_sequence(collate_x, padding_value=vocab_de['<pad>'], batch_first=True)
    collate_y = nn.utils.rnn.pad_sequence(collate_y, padding_value=vocab_en['<pad>'], batch_first=True)
    return (collate_x, collate_y)

In [90]:
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn= pad_collate_fn)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn= pad_collate_fn)

In [91]:
encoder = Encoder(vocab_size=len(vocab_de), hidden_size=256, pwff_size=512, n_layers=3, n_heads=8, device=device, max_size=100)
decoder = Decoder(vocab_size=len(vocab_en), hidden_size=256, pwff_size=512, n_layers=3, n_heads=8, device=device, max_size=100)

In [92]:
model = Transformer(encoder, decoder, vocab_de['<pad>'], vocab_en['<pad>'], device).to(device)

In [93]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 14,599,302 trainable parameters


In [94]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss(ignore_index = vocab_en['<pad>'])

In [95]:
def train(dataloader, epoch):
    model.train()
    epoch_loss = 0
    epoch_acc = 0
    data_num = 0
    for src,tgt in tqdm(dataloader, desc=f"Epoch {epoch}"):
        src, tgt = src.to(device), tgt.to(device)
        outputs, _ = model(src, tgt)
        #outputs: [batch_size, tgt_len, output_size]
        #tgt: [batch_size, tgt_len, output_size]
        output_size = outputs.shape[-1]
        
        output = outputs.reshape(-1,output_size)
        #output: [(tgt_len - 1) * batch_size, output_size]
        tgt = tgt.reshape(-1)
        #tgt: [(tgt_len - 1) * batch_size]

        optimizer.zero_grad()
        loss = criterion(output, tgt)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        data_num += src.size(0)

    print(f"Train Epoch: {epoch}, Loss: {epoch_loss/len(dataloader)}, PPL: {math.exp(epoch_loss/len(dataloader))}")


In [96]:
def evaluate(dataloader, epoch):
    model.eval()
    epoch_loss = 0
    epoch_acc = 0
    data_num = 0
    for src,tgt in tqdm(dataloader, desc=f"Epoch {epoch}"):
        src, tgt = src.to(device), tgt.to(device)
        outputs, _ = model(src, tgt)
        #outputs: [batch_size, tgt_len, output_size]
        #tgt: [batch_size, tgt_len, output_size]
        output_size = outputs.shape[-1]
        output = outputs.reshape(-1,output_size)
        #output: [(tgt_len - 1) * batch_size, output_size]
        tgt = tgt.reshape(-1)
        #tgt: [(tgt_len - 1) * batch_size]

        loss = criterion(output, tgt)
        epoch_loss += loss.item()
        data_num += src.size(0)

    print(f"Evaluate Epoch: {epoch}, Loss: {epoch_loss/len(dataloader)}, PPL: {math.exp(epoch_loss/len(dataloader))}")


In [97]:

for epoch in range(num_epochs):
    train(train_dataloader, epoch)
    evaluate(valid_dataloader, epoch)
    print('-'*50)

Epoch 0: 100%|██████████| 91/91 [00:15<00:00,  5.92it/s]


Train Epoch: 0, Loss: 1.748643146438913, PPL: 5.746799817737287


Epoch 0: 100%|██████████| 23/23 [00:01<00:00, 16.01it/s]


Evaluate Epoch: 0, Loss: 0.5272486222826916, PPL: 1.6942643287540422
--------------------------------------------------


Epoch 1: 100%|██████████| 91/91 [00:15<00:00,  5.91it/s]


Train Epoch: 1, Loss: 0.3171060371857423, PPL: 1.3731481689734168


Epoch 1: 100%|██████████| 23/23 [00:01<00:00, 16.22it/s]


Evaluate Epoch: 1, Loss: 0.27965009471644525, PPL: 1.3226669232136692
--------------------------------------------------


Epoch 2: 100%|██████████| 91/91 [00:14<00:00,  6.07it/s]


Train Epoch: 2, Loss: 0.1468359622490275, PPL: 1.1581639648347941


Epoch 2: 100%|██████████| 23/23 [00:01<00:00, 15.22it/s]


Evaluate Epoch: 2, Loss: 0.20513340960378232, PPL: 1.2276888395200338
--------------------------------------------------


Epoch 3: 100%|██████████| 91/91 [00:15<00:00,  6.04it/s]


Train Epoch: 3, Loss: 0.07461440952105837, PPL: 1.0774686091382675


Epoch 3: 100%|██████████| 23/23 [00:01<00:00, 15.86it/s]


Evaluate Epoch: 3, Loss: 0.1716972369214763, PPL: 1.187318302645865
--------------------------------------------------


Epoch 4: 100%|██████████| 91/91 [00:15<00:00,  5.95it/s]


Train Epoch: 4, Loss: 0.0348629158113029, PPL: 1.0354777514423206


Epoch 4: 100%|██████████| 23/23 [00:01<00:00, 15.59it/s]


Evaluate Epoch: 4, Loss: 0.15637925299613373, PPL: 1.1692695679977712
--------------------------------------------------


Epoch 5: 100%|██████████| 91/91 [00:15<00:00,  5.85it/s]


Train Epoch: 5, Loss: 0.011925492640380021, PPL: 1.011996884841202


Epoch 5: 100%|██████████| 23/23 [00:01<00:00, 15.35it/s]


Evaluate Epoch: 5, Loss: 0.1501953433389249, PPL: 1.1620612214772066
--------------------------------------------------


Epoch 6: 100%|██████████| 91/91 [00:16<00:00,  5.65it/s]


Train Epoch: 6, Loss: 0.004746299472402085, PPL: 1.0047575809931732


Epoch 6: 100%|██████████| 23/23 [00:01<00:00, 12.65it/s]


Evaluate Epoch: 6, Loss: 0.14988347207722458, PPL: 1.1616988644851947
--------------------------------------------------


Epoch 7: 100%|██████████| 91/91 [00:15<00:00,  6.01it/s]


Train Epoch: 7, Loss: 0.002985106190093435, PPL: 1.002989566056197


Epoch 7: 100%|██████████| 23/23 [00:01<00:00, 14.95it/s]


Evaluate Epoch: 7, Loss: 0.14896443615788998, PPL: 1.1606317119521492
--------------------------------------------------


Epoch 8: 100%|██████████| 91/91 [00:15<00:00,  5.94it/s]


Train Epoch: 8, Loss: 0.0021932763098687917, PPL: 1.002195683299764


Epoch 8: 100%|██████████| 23/23 [00:01<00:00, 15.75it/s]


Evaluate Epoch: 8, Loss: 0.14893549875072812, PPL: 1.160598126765672
--------------------------------------------------


Epoch 9: 100%|██████████| 91/91 [00:15<00:00,  5.96it/s]


Train Epoch: 9, Loss: 0.0017178093126698673, PPL: 1.0017192855922885


Epoch 9: 100%|██████████| 23/23 [00:01<00:00, 15.24it/s]

Evaluate Epoch: 9, Loss: 0.14886186854994815, PPL: 1.1605126748385333
--------------------------------------------------



