In [1]:
import torch
from torch import nn
import numpy as np
from pathlib import Path

torch.backends.cudnn.benchmark = True
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

import sys
sys.path.append("..")

from nmt.models import NMT
from nmt.datasets import read_corpus, batch_iter, Vocab
import tqdm

In [2]:
data_loc = Path("..") / "nmt" / "datasets" / "data"
en_es_data_loc = data_loc / "en_es_data"
train_data_src_path = en_es_data_loc / "train_tiny.es"
train_data_tgt_path = en_es_data_loc / "train_tiny.en"
dev_data_src_path = en_es_data_loc / "dev_tiny.es"
dev_data_tgt_path = en_es_data_loc / "dev_tiny.en"
vocab_path = data_loc / "vocab_tiny_q2.json"

In [3]:
train_src = read_corpus(train_data_src_path)
train_tgt = read_corpus(train_data_tgt_path, is_target=True)
vocab = Vocab.load(vocab_path)

In [4]:
BATCH_SIZE=2
MAX_EPOCH=201
SEED=42
EMBEDDING_SIZE=256
HIDDEN_SIZE=256
GRAD_CLIP=5.0
UNIFORM_INIT=0.1
USE_CHAR_DECODER=True
LEARNING_RATE=0.001

In [5]:
model = NMT(
    vocab=vocab,
    embedding_dim=EMBEDDING_SIZE,
    hidden_size=HIDDEN_SIZE,
    use_char_decoder=True
)
model.train()

NMT(
  (encoder): Encoder(
    (embedding): CharEmbedding(
      (char_embed): Embedding(97, 50, padding_idx=0)
      (cnn_embed): CharCNNEmbedding(
        (conv): Conv1d(50, 256, kernel_size=(5,), stride=(1,))
        (maxpool): AdaptiveMaxPool1d(output_size=1)
      )
      (highway): Highway(
        (linear): Linear(in_features=256, out_features=256, bias=True)
        (gate): Linear(in_features=256, out_features=256, bias=True)
      )
      (dropout): Dropout(p=0.3, inplace=False)
    )
    (encoder): LSTM(256, 256, num_layers=2, bidirectional=True)
    (hidden_projection): Linear(in_features=512, out_features=256, bias=False)
    (cell_projection): Linear(in_features=512, out_features=256, bias=False)
  )
  (decoder): Decoder(
    (embedding): CharEmbedding(
      (char_embed): Embedding(97, 50, padding_idx=0)
      (cnn_embed): CharCNNEmbedding(
        (conv): Conv1d(50, 256, kernel_size=(5,), stride=(1,))
        (maxpool): AdaptiveMaxPool1d(output_size=1)
      )
      (hig

In [6]:
uniform_init = UNIFORM_INIT
if np.abs(uniform_init) > 0.:
    print('uniformly initialize parameters [-%f, +%f]' %
            (uniform_init, uniform_init), file=sys.stderr)
    for p in model.parameters():
        p.data.uniform_(-uniform_init, uniform_init)

uniformly initialize parameters [-0.100000, +0.100000]


In [7]:
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [8]:
epoch = 0

In [10]:
for epoch in range(MAX_EPOCH):
    cum_loss = 0
    for i, (src_sents, tgt_sents) in enumerate(batch_iter((train_src, train_tgt), batch_size=BATCH_SIZE, shuffle=True)):
        optimizer.zero_grad()
        batch_size = len(src_sents)

        batch_loss = -model(src_sents, tgt_sents).sum()
        batch_loss /= batch_size
        cum_loss += batch_loss
        batch_loss.backward()

         # clip gradient
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()
    cum_loss /= len(train_src)
    print(f"Epoch: {str(epoch).zfill(3)} - Cumulative loss: {cum_loss}")

Epoch: 000 - Cumulative loss: 72.90512084960938
Epoch: 001 - Cumulative loss: 69.88958740234375
Epoch: 002 - Cumulative loss: 68.48753356933594
Epoch: 003 - Cumulative loss: 66.44123840332031
Epoch: 004 - Cumulative loss: 63.33734130859375
Epoch: 005 - Cumulative loss: 60.675079345703125
Epoch: 006 - Cumulative loss: 57.88020706176758
Epoch: 007 - Cumulative loss: 54.3707389831543
Epoch: 008 - Cumulative loss: 51.51618194580078
Epoch: 009 - Cumulative loss: 49.38239669799805
Epoch: 010 - Cumulative loss: 48.00582504272461
Epoch: 011 - Cumulative loss: 47.20466995239258
Epoch: 012 - Cumulative loss: 46.401790618896484
Epoch: 013 - Cumulative loss: 46.17249298095703
Epoch: 014 - Cumulative loss: 45.65177917480469
Epoch: 015 - Cumulative loss: 45.349369049072266
Epoch: 016 - Cumulative loss: 44.933319091796875
Epoch: 017 - Cumulative loss: 44.613555908203125
Epoch: 018 - Cumulative loss: 44.37868118286133
Epoch: 019 - Cumulative loss: 44.22948455810547
Epoch: 020 - Cumulative loss: 44.304