In [1]:
import math
import os
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext.datasets import Multi30k
from dataloader import *
from utils import *
from torch.optim import Adam
from tqdm import tqdm

In [2]:
# --- Training schedule ---
N_EPOCHS = 10
CLIP = 1  # passed to train(); presumably the gradient-clipping threshold — confirm in utils
best_valid_loss = float("inf")

# --- Model hyper-parameters ---
emb_dim = 256
enc_hid_dim = dec_hid_dim = 512
n_layers = 10
dropout = 0.25
batch_size = 128
kernel_size = 3  # must be odd for the encoder's symmetric "same" padding

# --- Reproducibility ---
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)

# `is_available` must be *called*: the bare function object is always truthy,
# which previously forced "cuda" even on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Load the three Multi30k splits and turn each iterable split into a
# map-style (indexable, sized) dataset in one pass.
train_dataset, val_dataset, test_dataset = (
    to_map_style_dataset(split) for split in Multi30k(root="data")
)

# The helper returns the three loaders plus a 4-tuple of extras; only the last
# two extras (bound to vocab_de / vocab_en) are used by this notebook.
(
    train_dataloader,
    val_dataloader,
    test_dataloader,
    etc,
) = get_dataloader_and_etc(train_dataset, val_dataset, test_dataset, batch_size)
_, _, vocab_de, vocab_en = etc

# Vocabulary sizes fix the encoder's input and decoder's output dimensions.
input_dim, output_dim = len(vocab_de), len(vocab_en)

In [4]:
class Encoder(nn.Module):
    """Convolutional encoder with GLU-gated, scaled-residual conv layers.

    Token and position embeddings are summed, projected to ``hid_dim``, run
    through ``n_layers`` 1-D convolutions (each followed by GLU and a scaled
    residual connection), then projected back to ``emb_dim``.

    Args:
        input_dim: source vocabulary size.
        emb_dim: embedding size.
        hid_dim: channel count of the convolutional stack.
        n_layers: number of convolutional layers (0 is allowed).
        kernel_size: convolution width; must be odd so the symmetric padding
            keeps the sequence length unchanged (needed by the residual add).
        dropout: dropout probability for embeddings and every conv input.
        max_length: size of the position-embedding table; position ids must be
            smaller than this.

    Raises:
        ValueError: if ``kernel_size`` is even.
    """

    def __init__(
        self,
        input_dim,
        emb_dim,
        hid_dim,
        n_layers,
        kernel_size,
        dropout,
        max_length=100,
    ):
        super().__init__()
        if kernel_size % 2 == 0:
            # Even kernels with padding (k-1)//2 shorten the sequence by one,
            # which would only surface later as a residual shape mismatch.
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")
        self.tok_embedding = nn.Embedding(input_dim, emb_dim)
        self.pos_embedding = nn.Embedding(max_length, emb_dim)
        self.emb2hid = nn.Linear(emb_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        self.convs = nn.ModuleList(
            [
                nn.Conv1d(
                    in_channels=hid_dim,
                    # doubled so GLU can split into values and gates
                    out_channels=2 * hid_dim,
                    kernel_size=kernel_size,
                    padding=(kernel_size - 1) // 2,
                )
                for _ in range(n_layers)
            ]
        )
        self.hid2emb = nn.Linear(hid_dim, emb_dim)

    def forward(self, src, pos, scale):
        """Encode a batch of source sequences.

        Args:
            src: [bs, src_len] source token ids.
            pos: [bs, src_len] position ids, each < ``max_length``.
            scale: residual scaling factor (Seq2Seq passes sqrt(0.5)).

        Returns:
            conved: [bs, emb_dim, src_len] conv-stack output, already
                transposed so the decoder can matmul against it directly.
            combined: [bs, src_len, emb_dim], ``(conved + embedded) * scale``.
        """
        embedded = self.dropout(self.tok_embedding(src) + self.pos_embedding(pos))
        # embedded [bs, src_len, emb_dim]
        conv_input = self.emb2hid(embedded).permute(0, 2, 1)
        # conv_input [bs, hid_dim, src_len] — Conv1d expects channels first
        for conv in self.convs:
            conved = conv(self.dropout(conv_input))
            # conved [bs, hid_dim*2, src_len]
            conved = F.glu(conved, dim=1)
            # conved [bs, hid_dim, src_len]
            conv_input = (conved + conv_input) * scale
        # Read the result from conv_input (not the loop variable) so that
        # n_layers == 0 works instead of raising UnboundLocalError.
        conved = self.hid2emb(conv_input.permute(0, 2, 1))
        # conved [bs, src_len, emb_dim]
        combined = (conved + embedded) * scale
        # combined [bs, src_len, emb_dim]
        return conved.permute(0, 2, 1), combined

In [5]:
class Decoder(nn.Module):
    """Causal convolutional decoder with attention over the encoder after every layer.

    Args:
        output_dim: target vocabulary size (size of the output scores).
        emb_dim: embedding size; must match the encoder's ``emb_dim``.
        hid_dim: channel count of the convolutional stack.
        n_layers: number of convolutional layers (0 is allowed; then no
            attention is computed and ``forward`` returns ``attention=None``).
        kernel_size: convolution width; ``forward`` expects the caller to
            supply ``kernel_size - 1`` columns of left padding.
        dropout: dropout probability.
        max_length: size of the position-embedding table.
    """

    def __init__(
        self,
        output_dim,
        emb_dim,
        hid_dim,
        n_layers,
        kernel_size,
        dropout,
        max_length=100,
    ):
        super().__init__()
        self.tok_embedding = nn.Embedding(output_dim, emb_dim)
        self.pos_embedding = nn.Embedding(max_length, emb_dim)
        self.emb2hid = nn.Linear(emb_dim, hid_dim)
        self.attn_hid2emb = nn.Linear(hid_dim, emb_dim)
        self.attn_emb2hid = nn.Linear(emb_dim, hid_dim)
        # No built-in padding: forward() prepends padding on the left only, so
        # output position t depends on input positions <= t.
        self.convs = nn.ModuleList(
            [
                nn.Conv1d(
                    in_channels=hid_dim,
                    out_channels=2 * hid_dim,
                    kernel_size=kernel_size,
                )
                for _ in range(n_layers)
            ]
        )
        self.hid2emb = nn.Linear(hid_dim, emb_dim)
        self.fc_out = nn.Linear(emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def calculate_attention(
        self, embedded, conved, encoder_conved, encoder_combined, scale
    ):
        """Attend from each target position over the encoded source.

        Args:
            embedded: [bs, trg_len, emb_dim] decoder input embeddings.
            conved: [bs, trg_len, hid_dim] output of the current conv layer.
            encoder_conved: [bs, emb_dim, src_len], used for the energies.
            encoder_combined: [bs, src_len, emb_dim], averaged with the
                attention weights.
            scale: residual scaling factor.

        Returns:
            attention: [bs, trg_len, src_len], softmax over ``src_len``.
            attended_combined: [bs, hid_dim, trg_len], channels-first so it
                can be added straight onto the conv input.
        """
        conved_emb = self.attn_hid2emb(conved)
        # conved_emb [bs, trg_len, emb_dim]
        combined = (conved_emb + embedded) * scale
        # NOTE(review): no source padding mask — padded source positions can
        # receive attention weight.
        energy = torch.matmul(combined, encoder_conved)
        # energy [bs, trg_len, src_len]
        attention = F.softmax(energy, dim=2)
        attended_encoding = torch.matmul(attention, encoder_combined)
        # attended_encoding [bs, trg_len, emb_dim]
        attended_encoding = self.attn_emb2hid(attended_encoding)
        # attended_encoding [bs, trg_len, hid_dim]
        attended_combined = (conved + attended_encoding) * scale
        return attention, attended_combined.permute(0, 2, 1)

    def forward(self, trg, encoder_conved, encoder_combined, pos, padding, scale):
        """Decode a batch of target sequences in parallel.

        Args:
            trg: [bs, trg_len] target token ids.
            encoder_conved: [bs, emb_dim, src_len] from the encoder.
            encoder_combined: [bs, src_len, emb_dim] from the encoder.
            pos: [bs, trg_len] position ids, each < ``max_length``.
            padding: [bs, hid_dim, kernel_size - 1] left padding prepended to
                every conv input (Seq2Seq passes zeros).
            scale: residual scaling factor.

        Returns:
            output: [bs, trg_len, output_dim] unnormalised scores.
            attention: [bs, trg_len, src_len] from the last conv layer, or
                None when the decoder has no conv layers.
        """
        embedded = self.dropout(self.tok_embedding(trg) + self.pos_embedding(pos))
        # embedded [bs, trg_len, emb_dim]
        conv_input = self.emb2hid(embedded).permute(0, 2, 1)
        # conv_input [bs, hid_dim, trg_len]
        attention = None  # stays None when n_layers == 0 (previously unbound)
        for conv in self.convs:
            conv_input = self.dropout(conv_input)
            padded_conv_input = torch.cat((padding, conv_input), dim=2)
            # padded_conv_input [bs, hid_dim, trg_len + kernel_size - 1]
            conved = F.glu(conv(padded_conv_input), dim=1)
            # conved [bs, hid_dim, trg_len]
            attention, conved = self.calculate_attention(
                embedded,
                conved.permute(0, 2, 1),
                encoder_conved,
                encoder_combined,
                scale,
            )
            # conved [bs, hid_dim, trg_len]
            conv_input = (conved + conv_input) * scale
        conved = self.hid2emb(self.dropout(conv_input.permute(0, 2, 1)))
        # conved [bs, trg_len, emb_dim]
        output = self.fc_out(conved)
        return output, attention

In [6]:
class Seq2Seq(nn.Module):
    """Wire an encoder and decoder together for time-major input batches.

    Args:
        encoder: module called as ``encoder(src, pos, scale)`` returning
            ``(encoder_conved, encoder_combined)``.
        decoder: module called as
            ``decoder(trg, encoder_conved, encoder_combined, pos, padding, scale)``.
            It must expose ``emb2hid`` (a Linear whose ``out_features`` is the
            decoder hidden size) and ``convs`` (its Conv1d stack); these size
            the decoder's left padding.
        device: device on which positions, padding and ``scale`` are created.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        # Residual scaling factor sqrt(0.5), shared by encoder and decoder.
        self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)

    def _decoder_padding(self, batch_size):
        """Return zero padding [bs, dec_hid_dim, kernel_size - 1] for the decoder convs.

        Sized from the decoder's own layers instead of notebook globals, so the
        model stays correct for any decoder configuration.
        """
        hid_dim = self.decoder.emb2hid.out_features
        convs = self.decoder.convs
        pad_len = convs[0].kernel_size[0] - 1 if len(convs) > 0 else 0
        return torch.zeros(batch_size, hid_dim, pad_len, device=self.device)

    def forward(self, src, trg):
        """Run the encoder and decoder over a batch.

        Args:
            src: [src_len, bs] source token ids (time-major; transposed here).
            trg: [trg_len, bs] target token ids (time-major; transposed here).

        Returns:
            ``(output, attention)`` exactly as returned by the decoder.
        """
        src = src.permute(1, 0)
        trg = trg.permute(1, 0)
        batch_size, src_len = src.shape
        trg_len = trg.shape[1]
        src_pos = torch.arange(0, src_len, device=self.device)[None, :].repeat(
            batch_size, 1
        )
        # src_pos [bs, src_len]
        encoder_conved, encoder_combined = self.encoder(src, src_pos, self.scale)
        trg_pos = torch.arange(0, trg_len, device=self.device)[None, :].repeat(
            batch_size, 1
        )
        # trg_pos [bs, trg_len]
        padding = self._decoder_padding(batch_size)
        return self.decoder(
            trg, encoder_conved, encoder_combined, trg_pos, padding, self.scale
        )

In [7]:
# Build the encoder/decoder pair with explicit keyword arguments so each
# hyper-parameter is visibly routed to the right slot.
enc = Encoder(
    input_dim=input_dim,
    emb_dim=emb_dim,
    hid_dim=enc_hid_dim,
    n_layers=n_layers,
    kernel_size=kernel_size,
    dropout=dropout,
)
dec = Decoder(
    output_dim=output_dim,
    emb_dim=emb_dim,
    hid_dim=dec_hid_dim,
    n_layers=n_layers,
    kernel_size=kernel_size,
    dropout=dropout,
)
model = Seq2Seq(encoder=enc, decoder=dec, device=device).to(device)
optimizer = Adam(model.parameters())
# NOTE(review): ignore_index=0 assumes the target padding token id is 0 —
# confirm against how vocab_en is built in dataloader.
criterion = nn.CrossEntropyLoss(ignore_index=0)
count_parameters(model)
model.apply(init_weights3)

The model has 37,351,173 trainable parameters


Seq2Seq(
  (encoder): Encoder(
    (tok_embedding): Embedding(7853, 256)
    (pos_embedding): Embedding(100, 256)
    (emb2hid): Linear(in_features=256, out_features=512, bias=True)
    (dropout): Dropout(p=0.25, inplace=False)
    (convs): ModuleList(
      (0-9): 10 x Conv1d(512, 1024, kernel_size=(3,), stride=(1,), padding=(1,))
    )
    (hid2emb): Linear(in_features=512, out_features=256, bias=True)
  )
  (decoder): Decoder(
    (tok_embedding): Embedding(5893, 256)
    (pos_embedding): Embedding(100, 256)
    (emb2hid): Linear(in_features=256, out_features=512, bias=True)
    (attn_hid2emb): Linear(in_features=512, out_features=256, bias=True)
    (attn_emb2hid): Linear(in_features=256, out_features=512, bias=True)
    (convs): ModuleList(
      (0-9): 10 x Conv1d(512, 1024, kernel_size=(3,), stride=(1,))
    )
    (hid2emb): Linear(in_features=512, out_features=256, bias=True)
    (fc_out): Linear(in_features=256, out_features=5893, bias=True)
    (dropout): Dropout(p=0.25, inpl

In [8]:
# Batches per epoch, passed to train()/evaluate() (presumably the progress-bar
# total / averaging denominator — see utils). The original
# `math.ceil(len(ds) // batch_size)` floor-divided *before* the ceil, making the
# ceil a no-op. The loaders' own length is exact whether or not the last
# partial batch is dropped.
# NOTE(review): assumes get_dataloader_and_etc returns sized loaders
# (e.g. torch DataLoader) — confirm in dataloader.py.
t_batch = len(train_dataloader)
v_batch = len(val_dataloader)

# torch.save does not create missing parent directories.
os.makedirs("weight", exist_ok=True)

for epoch in range(N_EPOCHS):
    train(
        epoch,
        model,
        train_dataloader,
        t_batch,
        optimizer,
        criterion,
        CLIP,
        device,
        mode="cnn",
    )
    eval_loss = evaluate(model, val_dataloader, v_batch, criterion, device, mode="cnn")
    # Keep only the checkpoint with the lowest validation loss seen so far.
    if eval_loss < best_valid_loss:
        best_valid_loss = eval_loss
        torch.save(model.state_dict(), "weight/tut5-model.pt")

Epoch: 1: 100%|██████████| 226/226 [00:13<00:00, 17.34it/s, train_loss=5.65]
100%|██████████| 7/7 [00:00<00:00, 63.80it/s, eval_loss=4.47]
Epoch: 2: 100%|██████████| 226/226 [00:11<00:00, 20.49it/s, train_loss=3.94]
100%|██████████| 7/7 [00:00<00:00, 62.43it/s, eval_loss=3.7] 
Epoch: 3: 100%|██████████| 226/226 [00:10<00:00, 20.58it/s, train_loss=3.34]
100%|██████████| 7/7 [00:00<00:00, 65.33it/s, eval_loss=3.15]
Epoch: 4: 100%|██████████| 226/226 [00:10<00:00, 20.59it/s, train_loss=2.87]
100%|██████████| 7/7 [00:00<00:00, 70.84it/s, eval_loss=2.69]
Epoch: 5: 100%|██████████| 226/226 [00:10<00:00, 20.62it/s, train_loss=2.47]
100%|██████████| 7/7 [00:00<00:00, 70.18it/s, eval_loss=2.35]
Epoch: 6: 100%|██████████| 226/226 [00:11<00:00, 20.51it/s, train_loss=2.2] 
100%|██████████| 7/7 [00:00<00:00, 65.57it/s, eval_loss=2.19]
Epoch: 7: 100%|██████████| 226/226 [00:10<00:00, 20.56it/s, train_loss=2.01]
100%|██████████| 7/7 [00:00<00:00, 69.67it/s, eval_loss=2.11]
Epoch: 8: 100%|██████████| 

In [9]:
# map_location lets a checkpoint written during a CUDA run be loaded on a
# CPU-only machine (and vice versa).
model.load_state_dict(torch.load("weight/tut5-model.pt", map_location=device))
# Use the loader's own batch count: the original
# `math.ceil(len(ds) // batch_size)` floor-divided before the ceil.
t_batch = len(test_dataloader)
evaluate(model, test_dataloader, t_batch, criterion, device, mode="cnn")

100%|██████████| 7/7 [00:00<00:00, 22.42it/s, eval_loss=1.95]


1.9496896437236242