In [49]:
import torch
import random
import math
import torch.nn as nn
import torch.nn.functional as F
from torchtext.datasets import Multi30k
from dataloader import *
from utils import *
from torch.optim import Adam
from tqdm import tqdm

In [50]:
# Training hyper-parameters.
N_EPOCHS = 10
CLIP = 1  # passed to train(); presumably the gradient-norm clip — see utils
best_valid_loss = float('inf')  # best validation loss so far; lowered by the training loop
emb_dim = 256
enc_hid_dim = dec_hid_dim = 512
n_layers = 2  # NOTE(review): not passed to Encoder/Decoder (their GRUs are single-layer) — currently unused
dropout = 0.5
batch_size = 128

# Seed Python's RNG (teacher forcing draws from `random`) and PyTorch's RNG
# (dropout, and presumably weight init) so runs are reproducible.
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)

In [51]:
# Load the Multi30k splits, turn each into a map-style dataset, then build the
# dataloaders together with the extra artefacts (which include both vocabularies).
train_dataset, val_dataset, test_dataset = (
    to_map_style_dataset(split) for split in Multi30k(root='data')
)
train_dataloader, val_dataloader, test_dataloader, etc = get_dataloader_and_etc(
    train_dataset, val_dataset, test_dataset, batch_size
)
# Only the last two extras are needed here: source (de) and target (en) vocabularies.
_, _, vocab_de, vocab_en = etc
input_dim = len(vocab_de)    # encoder vocabulary size
output_dim = len(vocab_en)   # decoder vocabulary size

In [52]:
class Encoder(nn.Module):
    """Bidirectional-GRU encoder for the attention seq2seq model.

    Embeds a time-major batch of source token ids, runs a single-layer
    bidirectional GRU over it (batch-first internally), and projects the
    concatenated final forward/backward states into the decoder's initial
    hidden state.
    """

    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        # Submodules are registered in this order; the names are the state_dict keys.
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.dropout = nn.Dropout(dropout)
        self.bigru = nn.GRU(emb_dim, enc_hid_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2 * enc_hid_dim, dec_hid_dim)

    def forward(self, src):
        """Encode a batch of source sequences.

        Args:
            src: token ids, shape [src_len, bs].

        Returns:
            outputs: [bs, src_len, 2 * enc_hid_dim] per-step GRU outputs
                (forward and backward halves concatenated).
            hidden: [1, bs, dec_hid_dim] tanh-projected summary used as the
                decoder's initial hidden state.
        """
        batch_major = src.permute(1, 0)                         # [bs, src_len]
        embedded = self.dropout(self.embedding(batch_major))    # [bs, src_len, emb_dim]
        outputs, final_states = self.bigru(embedded)            # final_states [2, bs, enc_hid_dim]
        # Single-layer bidirectional GRU: row 0 is the forward direction's
        # final state, row 1 the backward direction's.
        both_directions = torch.cat((final_states[0], final_states[1]), dim=1)  # [bs, 2*enc_hid_dim]
        decoder_init = torch.tanh(self.fc(both_directions)).unsqueeze(0)      # [1, bs, dec_hid_dim]
        return outputs, decoder_init

In [53]:
class Decoder(nn.Module):
    """Single-step GRU decoder with additive (concat + tanh) attention.

    Each call consumes one time-major step of target tokens, attends over the
    batch-first encoder outputs, advances the GRU one step and returns logits
    over the target vocabulary.
    """

    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        ctx_dim = 2 * enc_hid_dim  # width of the encoder outputs (both directions)
        # Submodules are registered in this order; the names are the state_dict keys.
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.dropout = nn.Dropout(dropout)
        self.attn = nn.Linear(ctx_dim + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)
        self.gru = nn.GRU(ctx_dim + emb_dim, dec_hid_dim)
        self.fc_out = nn.Linear(ctx_dim + dec_hid_dim + emb_dim, output_dim)
        self.output_dim = output_dim

    def calculate_attn(self, hidden, encoder_outputs):
        """Attention weights of the current decoder state over source positions.

        Args:
            hidden: [1, bs, dec_hid_dim] current decoder state.
            encoder_outputs: [bs, src_len, 2 * enc_hid_dim].

        Returns:
            [bs, src_len] softmax weights (each row sums to 1).

        NOTE(review): no padding mask is applied, so pad positions also receive weight.
        """
        steps = encoder_outputs.shape[1]
        query = hidden.permute(1, 0, 2).repeat(1, steps, 1)      # [bs, src_len, dec_hid_dim]
        energy = torch.tanh(self.attn(torch.cat((query, encoder_outputs), dim=2)))
        scores = self.v(energy).squeeze(2)                        # [bs, src_len]
        return F.softmax(scores, dim=1)

    def forward(self, input, hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            input: [1, bs] token ids for this step.
            hidden: [1, bs, dec_hid_dim] previous decoder state.
            encoder_outputs: [bs, src_len, 2 * enc_hid_dim].

        Returns:
            prediction: [bs, output_dim] logits.
            hidden: [1, bs, dec_hid_dim] updated decoder state.
        """
        embedded = self.dropout(self.embedding(input))                        # [1, bs, emb_dim]
        weights = self.calculate_attn(hidden, encoder_outputs)[:, None, :]    # [bs, 1, src_len]
        context = torch.bmm(weights, encoder_outputs).permute(1, 0, 2)        # [1, bs, 2*enc_hid_dim]
        rnn_output, hidden = self.gru(torch.cat((embedded, context), dim=2), hidden)
        # Drop the length-1 time axis; predict from [embedding; context; GRU output].
        features = torch.cat(
            (embedded.squeeze(0), context.squeeze(0), rnn_output.squeeze(0)), dim=1
        )
        return self.fc_out(features), hidden


In [54]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that unrolls the decoder step by step with teacher forcing.

    Args:
        encoder: module called as ``encoder(src) -> (encoder_outputs, hidden)``.
        decoder: single-step module exposing ``output_dim`` and called as
            ``decoder(input[1, bs], hidden, encoder_outputs) -> (logits[bs, output_dim], hidden)``.
        device: device on which the output buffer is allocated.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio = 0.5):
        """Decode ``trg_len - 1`` steps.

        Args:
            src: [src_len, bs] source token ids.
            trg: [trg_len, bs] target token ids; trg[0] is always the first decoder input.
            teacher_forcing_ratio: per-step probability of feeding the ground-truth
                token trg[t] instead of the model's previous argmax prediction.

        Returns:
            [trg_len - 1, bs, output_dim] logits; row t is produced from step t's
            input (trg[t] when teacher-forced).
        """
        trg_len, batch_size = trg.shape
        # Allocate directly on the target device rather than building the buffer
        # on the host and copying it over with .to() on every forward pass.
        outputs = torch.zeros(trg_len - 1, batch_size, self.decoder.output_dim, device=self.device)
        # With this notebook's Encoder: encoder_outputs [bs, src_len, enc_hid_dim*2],
        # hidden [1, bs, dec_hid_dim].
        encoder_outputs, hidden = self.encoder(src)
        prev_pred = trg[0, None]  # [1, bs]; step 0 feeds trg[0] on either branch
        for t in range(trg_len - 1):
            # One random draw per step (even at t = 0, where both branches coincide)
            # keeps the `random` stream consumption fixed per batch.
            use_truth = random.random() < teacher_forcing_ratio
            step_input = trg[t, None] if use_truth else prev_pred        # [1, bs]
            output, hidden = self.decoder(step_input, hidden, encoder_outputs)  # output [bs, output_dim]
            outputs[t] = output
            prev_pred = output.argmax(1)[None, :]
        return outputs

In [55]:
# Use the GPU when one is usable. torch.cuda.is_available must be *called*: the bare
# function object is always truthy and would force 'cuda' even on CPU-only hosts.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
enc = Encoder(input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout)
dec = Decoder(output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout)
model = Seq2Seq(enc, dec, device).to(device)
# Initialise weights before building the optimizer so it is guaranteed to hold the
# final parameter objects, even if init_weights3 replaces (rather than fills) them.
model.apply(init_weights3)
optimizer = Adam(model.parameters())
# NOTE(review): ignore_index=0 assumes index 0 of vocab_en is the padding token — confirm.
criterion = nn.CrossEntropyLoss(ignore_index=0)
count_parameters(model)
model

The model has 20,518,405trainable parameters


Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(7853, 256)
    (dropout): Dropout(p=0.5, inplace=False)
    (bigru): GRU(256, 512, batch_first=True, bidirectional=True)
    (fc): Linear(in_features=1024, out_features=512, bias=True)
  )
  (decoder): Decoder(
    (embedding): Embedding(5893, 256)
    (dropout): Dropout(p=0.5, inplace=False)
    (attn): Linear(in_features=1536, out_features=512, bias=True)
    (v): Linear(in_features=512, out_features=1, bias=False)
    (gru): GRU(1280, 512)
    (fc_out): Linear(in_features=1792, out_features=5893, bias=True)
  )
)

In [56]:
# Batches per epoch, handed to train()/evaluate() (apparently as the progress-bar total).
# Taken from the loaders so it always matches what they actually yield;
# math.ceil(len(ds) // batch_size) would floor first, making the ceil a no-op.
# NOTE(review): assumes the loaders implement __len__ (torch DataLoader over map-style data does).
t_batch = len(train_dataloader)
v_batch = len(val_dataloader)
for epoch in range(N_EPOCHS):
    train(epoch, model, train_dataloader, t_batch, optimizer, criterion, CLIP, device)
    eval_loss = evaluate(model, val_dataloader, v_batch, criterion, device)
    # Checkpoint only when validation loss improves.
    # The 'weight/' directory must already exist; torch.save does not create it.
    if eval_loss < best_valid_loss:
        best_valid_loss = eval_loss
        torch.save(model.state_dict(), 'weight/tut3-model.pt')

Epoch: 1: 100%|██████████| 226/226 [00:17<00:00, 13.23it/s, train_loss=4.99]
100%|██████████| 7/7 [00:00<00:00, 37.83it/s, eval_loss=4.76]
Epoch: 2: 100%|██████████| 226/226 [00:16<00:00, 13.64it/s, train_loss=4.08]
100%|██████████| 7/7 [00:00<00:00, 38.78it/s, eval_loss=3.78]
Epoch: 3: 100%|██████████| 226/226 [00:16<00:00, 13.72it/s, train_loss=3.35]
100%|██████████| 7/7 [00:00<00:00, 40.15it/s, eval_loss=3.09]
Epoch: 4: 100%|██████████| 226/226 [00:16<00:00, 13.67it/s, train_loss=2.81]
100%|██████████| 7/7 [00:00<00:00, 36.35it/s, eval_loss=2.68]
Epoch: 5: 100%|██████████| 226/226 [00:16<00:00, 13.52it/s, train_loss=2.45]
100%|██████████| 7/7 [00:00<00:00, 36.70it/s, eval_loss=2.51]
Epoch: 6:  19%|█▉        | 43/226 [00:02<00:11, 16.38it/s, train_loss=2.01]

In [None]:
# Reload the best checkpoint from the same path the training loop saves to, then
# report its loss on the test split. map_location lets a checkpoint written on one
# device load onto whichever `device` is active now.
model.load_state_dict(torch.load('weight/tut3-model.pt', map_location=device))
# Separate name so the training batch count `t_batch` is not clobbered.
test_batch = len(test_dataloader)
evaluate(model, test_dataloader, test_batch, criterion, device)