In [1]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from torchtext.datasets import Multi30k
from dataloader import *
from utils import *
from torch.optim import Adam

In [2]:
N_EPOCHS = 10
CLIP = 1
best_valid_loss = float("inf")
enc_hid_dim = dec_hid_dim = 256
enc_pf_dim = dec_pf_dim = 512
n_layers = 3
enc_heads = dec_heads = 8
dropout = 0.1
batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available else "cpu")
scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)

In [3]:
train_dataset, val_dataset, test_dataset = Multi30k(root="data")
train_dataset, val_dataset, test_dataset = (
    to_map_style_dataset(train_dataset),
    to_map_style_dataset(val_dataset),
    to_map_style_dataset(test_dataset),
)
train_dataloader, val_dataloader, test_dataloader, etc = get_dataloader_and_etc(
    train_dataset,
    val_dataset,
    test_dataset,
    batch_size,
)
_, _, vocab_de, vocab_en = etc
input_dim = len(vocab_de)
output_dim = len(vocab_en)

In [4]:
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout):
        super().__init__()
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        self.droput = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]
        Q = self.fc_q(query)
        # Q [bs, q_len, hid_dim]
        K = self.fc_k(key)
        # K [bs, k_len, hid_dim]
        V = self.fc_v(value)
        # V [bs, v_len, hid_dim]
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        # Q [bs, n_heads, q_len, head_dim]
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        # K [bs, n_heads, k_len, head_dim]
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        # V [bs, n_heads, v_len, head_dim]
        # K.permute [bs, n_heads, head_dim, k_len]
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / scale
        # energy [bs, n_heads, q_len, k_len]
        if mask is not None:
            energy = energy.masked_fill_(mask, -1e10)
        attention = self.softmax(energy)
        attention_d = self.droput(attention)
        # attention [bs, n_heads, q_len, k_len]
        out = torch.matmul(attention_d, V)
        # out [bs, n_heads, q_len, head_dim]
        out = out.permute(0, 2, 1, 3)
        out = out.reshape(batch_size, -1, self.hid_dim)
        # out [bs, q_len, hid_dim]
        out = self.fc_o(out)
        out = self.droput(out)
        # out [bs, q_len, hid_dim]
        return out, attention

In [5]:
class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.hid2pf = nn.Linear(hid_dim, pf_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.pf2hid = nn.Linear(pf_dim, hid_dim)

    def forward(self, x):
        # x [bs, seq_len, hid_dim]
        out = self.hid2pf(x)
        # out [bsm seq_len, pf_dim]
        out = self.relu(out)
        out = self.dropout(out)
        out = self.pf2hid(out)
        out = self.dropout(out)
        # out [bsm seq_len, hid_dim]
        return out

In [6]:
class EncoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, dropout):
        super().__init__()
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout)
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(
            hid_dim, pf_dim, dropout
        )
        self.ff_layer_norm = nn.LayerNorm(hid_dim)

    def forward(self, src, src_mask):
        _src, _ = self.self_attention(src, src, src, src_mask)
        # _src [bs, src_len, hid_dim]
        src_res = src + _src
        src = self.self_attn_layer_norm(src_res)
        _src = self.positionwise_feedforward(src)
        src_res = src + _src
        src = self.ff_layer_norm(src_res)
        # src [bs, src_len, hid_dim]
        return src

In [7]:
class Encoder(nn.Module):
    def __init__(
        self,
        input_dim,
        hid_dim,
        n_layers,
        n_heads,
        pf_dim,
        dropout,
        max_length=100,
    ):
        super().__init__()
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers = nn.ModuleList(
            [EncoderLayer(hid_dim, n_heads, pf_dim, dropout) for _ in range(n_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, pos, mask):
        tok_embedded = self.tok_embedding(src)
        pos_embedded = self.pos_embedding(pos)
        embedded = tok_embedded * scale + pos_embedded
        src_hid = self.dropout(embedded)
        # src_hid [bs, trg_len, hid_dim]
        for layer in self.layers:
            src_hid = layer(src_hid, mask)
        return src_hid

In [8]:
class DecoderLayer(nn.Module):
    def __init__(
        self,
        hid_dim,
        n_heads,
        pf_dim,
        dropout,
    ):
        super().__init__()
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout)
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(
            hid_dim, pf_dim, dropout
        )
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        _trg, _ = self.self_attention(trg, trg, trg, trg_mask)
        trg_res = trg + _trg
        trg = self.self_attn_layer_norm(trg_res)

        _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)
        trg_res = trg + _trg
        trg = self.enc_attn_layer_norm(trg_res)

        _trg = self.positionwise_feedforward(trg)
        trg_res = trg + _trg
        trg = self.ff_layer_norm(trg_res)

        return trg, attention

In [9]:
class Decoder(nn.Module):
    def __init__(
        self, output_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, max_length=100
    ):
        super().__init__()
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers = nn.ModuleList(
            [DecoderLayer(hid_dim, n_heads, pf_dim, dropout) for _ in range(n_layers)]
        )
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, pos, enc_src, trg_mask, src_mask):
        tok_embedded = self.tok_embedding(trg)
        pos_embedded = self.pos_embedding(pos)
        embedded = tok_embedded * scale + pos_embedded
        trg_hid = self.dropout(embedded)
        # trg_hid [bs, trg_len, hid_dim]
        for layer in self.layers:
            trg_hid, attention = layer(trg_hid, enc_src, trg_mask, src_mask)
        output = self.fc_out(trg_hid)
        return output, attention

In [10]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def make_src_mask(self, src):
        src_mask = (src == 0)[:, None, None, :]
        return src_mask

    def make_trg_mask(self, trg):
        mask = (trg == 0)[:, None, None, :]
        # mask [bs, 1, 1, trg_len]
        trg_len = trg.shape[1]
        sub_mask = torch.triu(
            torch.ones((trg_len, trg_len), device=self.device), 1
        ).bool()
        # sub_mask [trg_len, trg_len]
        mask = mask | sub_mask
        # mask [bs, 1, trg_len, trg_len]
        return mask

    def forward(self, src, trg):
        src = src.permute(1, 0)
        trg = trg.permute(1, 0)
        batch_size, src_len = src.shape
        _, trg_len = trg.shape
        src_mask = self.make_src_mask(src)
        # src_mask [bs, 1, 1, src_len]
        trg_mask = self.make_trg_mask(trg)
        # [bs, 1, trg_len, trg_len]
        src_pos = (
            torch.arange(0, src_len)[None, :].repeat(batch_size, 1).to(self.device)
        )
        # src_pos [bs, src_len]
        enc_src = self.encoder(src, src_pos, src_mask)
        # enc_src [bs, src_len, hid_dim]
        trg_pos = (
            torch.arange(0, trg_len)[None, :].repeat(batch_size, 1).to(self.device)
        )
        output, attention = self.decoder(trg, trg_pos, enc_src, trg_mask, src_mask)
        return output, attention

In [11]:
enc = Encoder(input_dim, enc_hid_dim, n_layers, enc_heads, enc_pf_dim, dropout)
dec = Decoder(output_dim, dec_hid_dim, n_layers, dec_heads, dec_pf_dim, dropout)
model = Seq2Seq(enc, dec, device).to(device)
optimizer = Adam(model.parameters(),0.0005)
criterion = nn.CrossEntropyLoss(ignore_index=0)
count_parameters(model)
model.apply(init_weights4)

The model has 9,038,341 trainable parameters


Seq2Seq(
  (encoder): Encoder(
    (tok_embedding): Embedding(7853, 256)
    (pos_embedding): Embedding(100, 256)
    (layers): ModuleList(
      (0-2): 3 x EncoderLayer(
        (self_attention): MultiHeadAttentionLayer(
          (fc_q): Linear(in_features=256, out_features=256, bias=True)
          (fc_k): Linear(in_features=256, out_features=256, bias=True)
          (fc_v): Linear(in_features=256, out_features=256, bias=True)
          (fc_o): Linear(in_features=256, out_features=256, bias=True)
          (droput): Dropout(p=0.1, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (positionwise_feedforward): PositionwiseFeedforwardLayer(
          (hid2pf): Linear(in_features=256, out_features=512, bias=True)
          (relu): ReLU()
          (dropout): Dropout(p=0.1, inplace=False)
          (pf2hid): Linear(in_features=512, out_features=256, bias=True)
        )
        (ff_l

In [12]:
t_batch = math.ceil(len(train_dataset) // batch_size)
v_batch = math.ceil(len(val_dataset) // batch_size)
for epoch in range(N_EPOCHS):
    train(
        epoch,
        model,
        train_dataloader,
        t_batch,
        optimizer,
        criterion,
        CLIP,
        device,
        mode="cnn",
    )
    eval_loss = evaluate(model, val_dataloader, v_batch, criterion, device, mode="cnn")
    if eval_loss < best_valid_loss:
        best_valid_loss = eval_loss
        torch.save(model.state_dict(), "weight/tut6-model.pt")

Epoch: 1: 100%|██████████| 226/226 [00:06<00:00, 37.63it/s, train_loss=5.06]
100%|██████████| 7/7 [00:00<00:00, 116.00it/s, eval_loss=4.73]
Epoch: 2: 100%|██████████| 226/226 [00:04<00:00, 48.24it/s, train_loss=4.19]
100%|██████████| 7/7 [00:00<00:00, 95.89it/s, eval_loss=4.14]
Epoch: 3: 100%|██████████| 226/226 [00:04<00:00, 47.89it/s, train_loss=3.92]
100%|██████████| 7/7 [00:00<00:00, 114.15it/s, eval_loss=3.87]
Epoch: 4: 100%|██████████| 226/226 [00:04<00:00, 48.46it/s, train_loss=3.76]
100%|██████████| 7/7 [00:00<00:00, 111.42it/s, eval_loss=3.8]
Epoch: 5: 100%|██████████| 226/226 [00:04<00:00, 48.42it/s, train_loss=3.65]
100%|██████████| 7/7 [00:00<00:00, 118.04it/s, eval_loss=3.58]
Epoch: 6: 100%|██████████| 226/226 [00:04<00:00, 48.23it/s, train_loss=3.56]
100%|██████████| 7/7 [00:00<00:00, 121.89it/s, eval_loss=3.54]
Epoch: 7: 100%|██████████| 226/226 [00:04<00:00, 48.25it/s, train_loss=3.47]
100%|██████████| 7/7 [00:00<00:00, 119.86it/s, eval_loss=3.42]
Epoch: 8: 100%|███████