In [23]:
import torch
import random
import math
import torch.nn as nn
from torchtext.datasets import Multi30k
from dataloader import *
from utils import *
from torch.optim import Adam
from tqdm import tqdm

In [24]:
N_EPOCHS = 10
CLIP = 1
best_valid_loss = float('inf')
emb_dim = 256
hid_dim = 512
n_layers = 2
dropout = 0.5
batch_size = 128

In [25]:
train_dataset, val_dataset, test_dataset = Multi30k(root='data')
train_dataloader, val_dataloader, test_dataloader, etc = get_dataloader_and_etc(train_dataset, val_dataset, test_dataset,batch_size)
_, _, vocab_de, vocab_en = etc
input_dim = len(vocab_de)
output_dim = len(vocab_en)

In [26]:
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)

    def forward(self, src):
        # src [src_len, bs]
        embedded = self.embedding(src)
        embedded = self.dropout(embedded)
        # embedded [src_len, bs, emb_dim]
        _, (hidden, cell) = self.lstm(embedded)
        # hidden [n_layers, bs, hid_dim]
        # cell [n_layers, bs, hid_dim]
        return hidden, cell

In [27]:
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.fc_out = nn.Linear(hid_dim, output_dim)

    def forward(self, input, hidden, cell):
        # input [1, bs]
        embedded = self.embedding(input)
        embedded = self.dropout(embedded)
        # embedded [1, bs, emb_dim]
        output, (hidden, cell) = self.lstm(embedded, (hidden, cell))
        # output [1, bs, hid_dim]
        # hidden [n_layers, bs, hid_dm]
        # cell [n_layer, bs, hid_dim]
        output = output.squeeze(0)
        # output [bs, hid_dim]
        prediciton = self.fc_out(output)
        # prediction [bs, output_dim]
        return prediciton, hidden, cell

In [28]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
    
    def forward(self, src, trg, teacher_forcing_ratio = 0.5):
        trg_len, batch_size = trg.shape
        output_dim = self.decoder.output_dim
        outputs = torch.zeros(trg_len-1, batch_size, output_dim).to(self.device)
        # outputs [trg_len-1, bs, output_dim]
        # src [src_len, bs]
        hidden, cell = self.encoder(src)
        # hidden [n_layers, bs, hid_dim]
        # cell [n_layers, bs, hid_dim]
        # trg [trg_len, bs]
        top1 = trg[0, None]
        for t in range(0, trg_len - 1):
            input = trg[t, None] if random.random()<teacher_forcing_ratio else  top1
            # input [1, bs]
            output, hidden, cell = self.decoder(input, hidden, cell)
            # output [bs, output_dim]
            outputs[t] = output
            top1 = output.argmax(1)[None,:]
        return outputs

In [29]:
device = torch.device('cuda' if torch.cuda.is_available else 'cpu')
enc = Encoder(input_dim, emb_dim, hid_dim, n_layers, dropout).to(device)
dec = Decoder(output_dim, emb_dim, hid_dim, n_layers, dropout).to(device)
model = Seq2Seq(enc, dec, device).to(device)

In [30]:
count_parameters(model)
model.apply(init_weights)

The model has 13,898,501trainable parameters


Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(7853, 256)
    (dropout): Dropout(p=0.5, inplace=False)
    (lstm): LSTM(256, 512, num_layers=2, dropout=0.5)
  )
  (decoder): Decoder(
    (embedding): Embedding(5893, 256)
    (dropout): Dropout(p=0.5, inplace=False)
    (lstm): LSTM(256, 512, num_layers=2, dropout=0.5)
    (fc_out): Linear(in_features=512, out_features=5893, bias=True)
  )
)

In [31]:
optimizer = Adam(model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=0)

In [32]:
def train(epoch, model, dataloader, n_batch ,optimizer, criterion, clip, device):
    model.train()
    train_loss = 0
    with tqdm(desc=f'Epoch:{epoch+1: 2d}',total=n_batch) as pbar:
        for i, (src, trg) in enumerate(dataloader):
            src = src.to(device)
            trg = trg.to(device)
            optimizer.zero_grad()
            output = model(src, trg)
            # output [trg_len-1, bs, output_dim]
            output_dim = output.shape[-1]
            output = output.view(-1, output_dim)
            trg = trg[1:].view(-1)
            loss = criterion(output, trg)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()
            train_loss += loss.item()
            avg_loss = train_loss/(i+1)
            pbar.set_postfix({'train_loss':avg_loss})
            pbar.update()
            if i == n_batch-1:
                break

In [33]:
def evaluate(epoch, model, dataloader, n_batch , criterion, device):
    model.eval()
    eval_loss = 0
    with tqdm(total=n_batch) as pbar:
        for i, (src, trg) in enumerate(dataloader):
            src = src.to(device)
            trg = trg.to(device)
            optimizer.zero_grad()
            output = model(src, trg)
            output_dim = output.shape[-1]
            output = output.view(-1, output_dim)
            trg = trg[1:].view(-1)
            loss = criterion(output, trg)
            eval_loss += loss.item()
            avg_loss = eval_loss/(i+1)
            pbar.set_postfix({'eval_loss':avg_loss})
            pbar.update()
            if i == n_batch-1:
                break
    return avg_loss

In [35]:
t_batch = math.ceil(len(to_map_style_dataset(train_dataset))//batch_size)
v_batch = math.ceil(len(to_map_style_dataset(val_dataset))//batch_size)
for epoch in range(N_EPOCHS):
    train(epoch, model, train_dataloader, t_batch, optimizer, criterion, CLIP, device)
    eval_loss = evaluate(epoch, model, val_dataloader, v_batch, criterion, device)
    if eval_loss < best_valid_loss:
        best_valid_loss = eval_loss
        torch.save(model.state_dict(), 'tut1-model.pt')

Epoch: 1: 100%|██████████| 226/226 [00:10<00:00, 21.01it/s, train_loss=4.99]
Epoch: 1: 100%|██████████| 7/7 [00:00<00:00, 52.43it/s, eval_loss=4.67]
Epoch: 2: 100%|██████████| 226/226 [00:10<00:00, 21.56it/s, train_loss=4.44]
Epoch: 2: 100%|██████████| 7/7 [00:00<00:00, 51.94it/s, eval_loss=4.15]
Epoch: 3: 100%|██████████| 226/226 [00:10<00:00, 21.44it/s, train_loss=4.13]
Epoch: 3: 100%|██████████| 7/7 [00:00<00:00, 53.72it/s, eval_loss=4.04]
Epoch: 4: 100%|██████████| 226/226 [00:10<00:00, 21.41it/s, train_loss=3.88]
Epoch: 4: 100%|██████████| 7/7 [00:00<00:00, 51.80it/s, eval_loss=3.74]
Epoch: 5: 100%|██████████| 226/226 [00:10<00:00, 21.53it/s, train_loss=3.66]
Epoch: 5: 100%|██████████| 7/7 [00:00<00:00, 50.87it/s, eval_loss=3.61]
Epoch: 6: 100%|██████████| 226/226 [00:10<00:00, 21.60it/s, train_loss=3.48]
Epoch: 6: 100%|██████████| 7/7 [00:00<00:00, 50.89it/s, eval_loss=3.41]
Epoch: 7: 100%|██████████| 226/226 [00:10<00:00, 21.47it/s, train_loss=3.29]
Epoch: 7: 100%|██████████| 7/