In [14]:
import torch
import torch.nn as nn
from torch import optim

from chameleon.dl_dataset import Vocabulary, TranslationDataset, TranslationCollator
from chameleon.models.lstm_lm import LanguageModel
from argparse import Namespace

import pprint

In [15]:
import pandas as pd


def get_data(data_root=None):
    """Load the tokenized train/valid/test splits from pickle files.

    Args:
        data_root: Directory holding the
            ``chameleon.{train,valid,test}.tok.pickle`` files. When omitted,
            the module-level ``DATA_ROOT`` is used, so existing zero-argument
            calls keep working.

    Returns:
        Tuple ``(train_data, valid_data, test_data)`` holding whatever
        ``pd.read_pickle`` returns for each file, in that order.
    """
    import os

    if data_root is None:
        data_root = DATA_ROOT

    # os.path.join makes the result independent of a trailing slash in data_root.
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    return tuple(
        pd.read_pickle(os.path.join(data_root, f"chameleon.{split}.tok.pickle"))
        for split in ("train", "valid", "test")
    )


# Directory prefix that get_data() reads the pickled splits from.
DATA_ROOT = "./data/"
# Load the train / valid / test splits once; later cells read these module-level names.
train_data, valid_data, test_data = get_data()

In [16]:
def get_config():
    """Return the experiment hyperparameters as an ``argparse.Namespace``.

    Every call builds a fresh namespace, so callers may mutate their copy.
    """
    return Namespace(
        lang="enko",  # first two chars = source language, last two = target
        gpu_id=0,
        use_mps=True,
        batch_size=16,
        n_epochs=2,
        max_length=512,
        dropout=0.2,
        word_vec_size=512,
        hidden_size=768,
        n_layers=4,
        max_grad_norm=1e8,  # passed to clip_grad_norm_ by the trainer
    )


# Hyperparameter namespace passed to get_loaders / get_model and the Trainer below.
config = get_config()

In [17]:
from torch.utils.data import DataLoader


def get_loaders(config):
    """Build the training and validation DataLoaders.

    Reads the module-level ``train_data`` and ``valid_data`` frames and pulls
    the ``tok_<lang>`` columns named by ``config.lang``. The validation dataset
    reuses the vocabularies built by the training dataset.

    Args:
        config: Namespace providing ``lang``, ``batch_size`` and ``max_length``.

    Returns:
        Tuple ``(train_loader, valid_loader)``.
    """
    # The language pair is encoded as the first two and last two characters.
    src_lang, tgt_lang = config.lang[:2], config.lang[-2:]
    print(f"source language: {src_lang}, target language: {tgt_lang}")

    src_col, tgt_col = f"tok_{src_lang}", f"tok_{tgt_lang}"

    def build_collator():
        # Both loaders use an identically configured collator.
        return TranslationCollator(
            pad_idx=Vocabulary.PAD,
            eos_idx=Vocabulary.EOS,
            max_length=config.max_length,
            with_text=False,
            is_dual=True,
        )

    # Training split: the dataset builds its own vocabularies.
    train_dataset = TranslationDataset(
        srcs=train_data[src_col].tolist(),
        tgts=train_data[tgt_col].tolist(),
        with_text=False,
        is_dual=True,  # tgt dataset also needs BOS and EOS tokens at the beginning and the end.
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        collate_fn=build_collator(),
    )

    # Validation split: reuse the training vocabularies.
    valid_dataset = TranslationDataset(
        srcs=valid_data[src_col].tolist(),
        tgts=valid_data[tgt_col].tolist(),
        src_vocab=train_dataset.src_vocab,
        tgt_vocab=train_dataset.tgt_vocab,
        with_text=False,
        is_dual=True,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        collate_fn=build_collator(),
    )

    return train_loader, valid_loader


# Build both loaders once; the training loader's dataset also supplies the vocabularies used below.
train_loader, valid_loader = get_loaders(config)

## ----- Test data_loader ----- #
# # input_ids would be dataset for source LM
# # output_ids would be dataset for target LM
# test_batch = next(iter(train_loader))
# print(test_batch)

source language: en, target language: ko
Number of vocabularies:  30488
Number of vocabularies:  53430


In [18]:
# get vocabulary size for source and target language
# Vocabulary sizes come from the vocabularies attached to the training dataset.
train_dataset = train_loader.dataset
src_vocab_size = len(train_dataset.src_vocab)
tgt_vocab_size = len(train_dataset.tgt_vocab)

In [19]:
# Build the LanguageModel for the source language only.
# In the actual implementation we will load two LMs,
# one for the source language and one for the target language.
def get_model(config):
    """Instantiate the source-side ``LanguageModel``.

    The vocabulary size is taken from the module-level ``src_vocab_size``;
    all remaining hyperparameters are read from ``config``.

    Args:
        config: Namespace providing ``word_vec_size``, ``hidden_size``,
            ``n_layers``, ``dropout`` and ``max_length``.

    Returns:
        The newly constructed ``LanguageModel``.
    """
    hyperparameter_names = (
        "word_vec_size",
        "hidden_size",
        "n_layers",
        "dropout",
        "max_length",
    )
    model_kwargs = {name: getattr(config, name) for name in hyperparameter_names}
    return LanguageModel(vocab_size=src_vocab_size, **model_kwargs)


# Source-language LM; moved to the accelerator and trained in later cells.
src_model = get_model(config)

In [20]:
# Show the module structure as a quick sanity check of the layer sizes.
print(src_model)

LanguageModel(
  (emb): Embedding(30488, 512, padding_idx=0)
  (rnn): LSTM(512, 768, num_layers=4, batch_first=True, dropout=0.2)
  (out): Linear(in_features=768, out_features=30488, bias=True)
  (log_softmax): LogSoftmax(dim=-1)
)


In [21]:
def get_crit(src_vocab_size, pad_idx):
    """Build a per-token NLL criterion that gives padding zero weight.

    TODO: if both LMs are trained at once, this needs to return two
    criteria (one per vocabulary).

    Args:
        src_vocab_size: Number of classes, i.e. the length of the weight vector.
        pad_idx: Class index whose weight is set to 0.

    Returns:
        ``nn.NLLLoss`` with ``reduction="none"``; callers reduce the
        per-element losses themselves.
    """
    # Uniform class weights, except that the padding class contributes nothing.
    class_weights = torch.ones(src_vocab_size)
    class_weights[pad_idx] = 0.0
    return nn.NLLLoss(weight=class_weights, reduction="none")


# Criterion for the source LM; the PAD class gets zero weight.
crit = get_crit(src_vocab_size, Vocabulary.PAD)

In [22]:
# load model to gpu
# Choose the compute device requested by the config. If that accelerator is
# not available on this machine, fall back to CPU instead of crashing.
if config.use_mps and torch.backends.mps.is_available():
    device = torch.device("mps:{}".format(config.gpu_id))
elif not config.use_mps and config.gpu_id >= 0 and torch.cuda.is_available():
    device = torch.device("cuda:{}".format(config.gpu_id))
else:
    if config.use_mps or config.gpu_id >= 0:
        print("Requested accelerator is not available; falling back to CPU.")
    device = torch.device("cpu")

src_model.to(device)
# The criterion must be moved as well: its class-weight tensor is internal
# state, and leaving it on another device yields a device-mismatch error.
# https://discuss.pytorch.org/t/move-the-loss-function-to-gpu/20060/5
crit.to(device)

In [23]:
# Confirm which device the model parameters live on after the move above.
next(src_model.parameters()).device

device(type='mps', index=0)

In [24]:
# load optimizer
# Adam over all source-LM parameters, with the optimizer's default hyperparameters.
optimizer = optim.Adam(src_model.parameters())

### Test tqdm ProgressBar

In [25]:
# import time
# import numpy as np

# n = 1000
# i = 0
# while i<n:
#     time.sleep(0.1)
#     value = np.random.randint(1, 100)
#     i += value
#     print(i, end = "\r")

# print("Loop Completed.")

In [26]:
# import time
# import numpy as np
# from tqdm import tqdm

# n = 1000
# i = 0
# pbar = tqdm(desc="while loop", total=n)
# while i<n:
#     time.sleep(0.1)
#     value = np.random.randint(1, 100)
#     i += value
#     pbar.update(value)
#     pbar.set_postfix(loss = i)
#     # print(i, end = "\r")

# print("Loop Completed.")
# pbar.close()

In [27]:
# iteration = 1000
# pbar = tqdm(range(iteration), desc="Epoch - 1", total = iteration)
# loss = 0
# for i in pbar:
#     loss += 2
#     pbar.set_postfix(loss = loss)
# pbar.close()
# print(f"Epoch - {loss / 30}")

### Test model forward

In [28]:
# test_batch = next(iter(train_loader))
# src_model.eval()

# with torch.no_grad():

#     device = next(src_model.parameters()).device
#     test_batch["input_ids"] = (
#         test_batch["input_ids"][0].to(device),
#         test_batch["input_ids"][1]
#     )

#     x = test_batch["input_ids"][0][:, :-1]
#     y = test_batch["input_ids"][0][:, 1:]
#     print(x.size(), y.size())

#     # forward
#     y_hat = src_model(x)
#     # |y_hat| = (batch_size, length, output_size)
#     print(y_hat)
#     print(y_hat.size())

#     # calculate loss
#     loss = crit(
#         y_hat.contiguous().view(-1, y_hat.size(-1)),
#         y.contiguous().view(-1),
#     ).sum() # criterion - reduction = None
#     print(loss)

In [None]:
# define Trainer
import numpy as np
from tqdm import tqdm
from copy import deepcopy


class Trainer:
    """Train/validate loop for a single language model.

    Each epoch runs one pass over ``train_loader`` (with parameter updates)
    and one pass over ``valid_loader`` (without gradients). A copy of the
    model weights with the lowest validation loss seen so far is kept and
    loaded back into the model when training finishes.

    Mini-batches are expected to be dicts whose ``"input_ids"`` entry is
    indexable as ``(token_id_tensor, lengths_tensor)``.

    Args:
        model: Module called as ``model(x)`` on a 2-D tensor of token ids.
        crit: Criterion returning per-element losses (they are summed here).
        optimizer: Optimizer over ``model``'s parameters.
        train_loader: Iterable of training mini-batches; must support ``len``.
        valid_loader: Iterable of validation mini-batches; must support ``len``.
        src_vocab: Stored on the instance; not used by the training loop.
        tgt_vocab: Stored on the instance; not used by the training loop.
        config: Namespace providing ``n_epochs`` and ``max_grad_norm``.
    """

    # Print a progress line every this many mini-batches.
    LOG_INTERVAL = 100

    def __init__(
        self,
        model,
        crit,
        optimizer,
        train_loader,
        valid_loader,
        src_vocab,
        tgt_vocab,
        config,
    ):
        self.model = model
        self.crit = crit
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.config = config

        self.latest_loss = None  # validation loss of the most recent epoch
        self.best_loss = None  # lowest validation loss seen so far
        self.best_model = None  # deep copy of the state_dict at best_loss

    def _batch_loss(self, mini_batch, device):
        """Run the model on one mini-batch and return its summed loss.

        Returns:
            Tuple ``(loss_sum, batch_size, word_count)`` where ``loss_sum`` is
            the criterion output summed over all positions (a tensor),
            ``batch_size`` is the first dimension of the targets and
            ``word_count`` is the sum of the batch's length tensor.
        """
        token_ids = mini_batch["input_ids"][0].to(device)
        lengths = mini_batch["input_ids"][1]

        # Next-token prediction: inputs drop the last position, targets the first.
        x = token_ids[:, :-1]
        y = token_ids[:, 1:]

        y_hat = self.model(x)
        # |y_hat| = (batch_size, length, output_size)

        # The criterion uses reduction="none", so flatten and sum explicitly.
        loss_sum = self.crit(
            y_hat.contiguous().view(-1, y_hat.size(-1)),
            y.contiguous().view(-1),
        ).sum()

        return loss_sum, y.size(0), int(lengths.sum())

    def _log_progress(self, epoch, idx, n_batches, loss):
        """Print loss and perplexity every ``LOG_INTERVAL`` mini-batches."""
        if (idx + 1) % self.LOG_INTERVAL == 0:
            ppl = np.exp(loss)
            print(f"Epoch - {epoch} - {idx+1}/{n_batches} - loss: {loss}, ppl: {ppl}")

    def _train(self, epoch):
        """Run one training pass; return the mean per-word loss over batches."""
        device = next(self.model.parameters()).device
        n_batches = len(self.train_loader)

        self.model.train()
        total_loss = 0
        for idx, mini_batch in enumerate(self.train_loader):
            self.optimizer.zero_grad()

            loss_sum, batch_size, word_count = self._batch_loss(mini_batch, device)

            # Back-propagate the summed loss averaged over sequences.
            loss_sum.div(batch_size).backward()

            # gradient clipping
            nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)

            # update model parameters
            self.optimizer.step()

            # Reported loss is normalized per word rather than per sequence.
            loss = float(loss_sum / word_count)
            total_loss += loss

            self._log_progress(epoch, idx, n_batches, loss)

        return total_loss / n_batches

    def _validate(self, epoch):
        """Run one validation pass; return the mean per-word loss over batches."""
        device = next(self.model.parameters()).device
        n_batches = len(self.valid_loader)

        self.model.eval()
        total_loss = 0
        with torch.no_grad():
            for idx, mini_batch in enumerate(self.valid_loader):
                loss_sum, _, word_count = self._batch_loss(mini_batch, device)

                loss = float(loss_sum / word_count)
                total_loss += loss

                self._log_progress(epoch, idx, n_batches, loss)

        return total_loss / n_batches

    def train(self):
        """Train for ``self.config.n_epochs`` epochs and return the model.

        After the last epoch the best weights recorded by ``check_best`` are
        loaded back into the model. If none were recorded (for example
        ``n_epochs == 0``) the model is returned as-is.
        """
        # Use the instance's config, not whatever global ``config`` exists.
        for epoch in range(self.config.n_epochs):
            epoch_no = epoch + 1  # 1-based in every log line
            train_loss = self._train(epoch_no)
            valid_loss = self._validate(epoch_no)

            print(
                f"Epoch - {epoch_no}: train_loss: {train_loss} ; valid_loss: {valid_loss}"
            )

            # update latest_loss
            self.latest_loss = valid_loss

            # snapshot the weights if this epoch is the best so far
            self.check_best()

        # Restore to best model.
        if self.best_model is not None:
            self.model.load_state_dict(self.best_model)

        return self.model

    def check_best(self):
        """Snapshot the model weights if ``latest_loss`` is the best so far.

        The first recorded loss is always snapshotted, so ``best_model`` is
        populated even when that loss is NaN.
        """
        loss = self.latest_loss
        if self.best_loss is None or loss <= self.best_loss:
            self.best_loss = loss
            self.best_model = deepcopy(self.model.state_dict())


# Wire the source LM, its criterion, optimizer and both loaders into a Trainer.
# src_vocab / tgt_vocab are passed as None: the Trainer only stores them.
LMTrainer = Trainer(
    model=src_model,
    crit=crit,
    optimizer=optimizer,
    train_loader=train_loader,
    valid_loader=valid_loader,
    src_vocab=None,
    tgt_vocab=None,
    config=config,
)

# Run training; train() returns the trainer's model (the same object as src_model)
# after loading the best recorded weights back into it.
model = LMTrainer.train()