# Train BiLSTM + Attn Model

## imports

In [1]:
%load_ext lab_black

In [2]:
import sys

sys.path.append("..")

In [3]:
import pickle
from functools import partial
from collections import OrderedDict

import yaml
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TestTubeLogger  # pip install test-tube

from models import BiLSTMAttn
from utils import NewsDataset, collate_fn
from utils.types_ import *

In [4]:
# Device configuration
GPU_NUM = 1
DEVICE = torch.device(f"cuda:{GPU_NUM}" if torch.cuda.is_available() else "cpu")

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Experiment Class with PL

In [5]:
class Experiment(pl.LightningModule):
    def __init__(self, model, params):
        super(Experiment, self).__init__()

        self.model = model
        self.params = params
        self._loss = nn.CrossEntropyLoss()

    # ---------------------
    # TRAINING
    # ---------------------
    def forward(self, sequences):
        return self.model(sequences)[0]

    def loss_function(self, preds, labels):
        ce_loss = self._loss(preds, labels)
        return ce_loss

    def training_step(self, batch, batch_idx):
        sequences, labels, keywords = batch

        preds = self.forward(sequences)
        train_loss = self.loss_function(preds, labels)
        tqdm_dict = {"train_CEE": train_loss}

        output = OrderedDict(
            {"loss": train_loss, "progress_bar": tqdm_dict, "log": tqdm_dict}
        )
        return output

    def validation_step(self, batch, batch_idx):
        sequences, labels, keywords = batch

        preds = self.forward(sequences)
        val_loss = self.loss_function(preds, labels)

        output = OrderedDict({"val_loss": val_loss,})
        return output

    def validation_epoch_end(self, outputs):
        val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean()
        return {"val_loss": val_loss_mean}

    def test_step(self, batch, batch_idx):
        sequences, labels, keywords = batch

        preds = self.forward(sequences)
        test_loss = self.loss_function(preds, labels)
        return {"test_loss": test_loss}

    def test_epoc_end(self, outputs):
        val_loss_mean = torch.stack([x["test_loss"] for x in outputs]).mean()
        return {"test_loss": val_loss_mean}

    # ---------------------
    # TRAINING SETUP
    # ---------------------
    def configure_optimizers(self):
        return torch.optim.Adam(
            self.model.parameters(), lr=self.params["LR"], weight_decay=1e-5
        )

## Train

In [6]:
config_path = "./config.yaml"
with open(config_path, "r") as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        print(exc)

In [7]:
config

{'model_params': {'embed_dim': 128,
  'hidden_dim': 256,
  'num_layers': 2,
  'bidirectional': True,
  'dropout_p': 0.3},
 'exp_params': {'data_path': '../data/tokenized/nouns_total_data.txt',
  'vocab_path': '../data/vocab/word_index.pkl',
  'batch_size': 64,
  'LR': 0.0001},
 'trainer_params': {'gpus': 1, 'max_epochs': 30},
 'logging_params': {'save_dir': 'logs/',
  'name': 'BiLSTMAttn',
  'manual_seed': 42}}

In [8]:
# ----------------
# DataLoader
# ----------------
data_path = config["exp_params"]["data_path"]
vocab_path = config["exp_params"]["vocab_path"]
labels_list = ["조선일보", "동아일보", "경향신문", "한겨레"]
labels_dict = {label: idx for idx, label in enumerate(labels_list)}

with open(vocab_path, "rb") as f:
    word_index = pickle.load(f)


dataset = NewsDataset(data_path)

train_loader = DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=True,
    collate_fn=partial(collate_fn, word_index=word_index, labels_dict=labels_dict),
)

dev_loader = DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=False,
    collate_fn=partial(collate_fn, word_index=word_index, labels_dict=labels_dict),
)

In [9]:
# ----------------
# SetUp Model
# ----------------

# vocab_size
config["model_params"]["vocab_size"] = len(word_index)
# num_class
config["model_params"]["num_class"] = len(labels_list)

model = BiLSTMAttn(**config["model_params"])
experiment = Experiment(model, config["exp_params"])

In [10]:
# model

In [11]:
# ----------------
# TestTubeLogger
# ----------------
tt_logger = TestTubeLogger(
    save_dir=config["logging_params"]["save_dir"],
    name=config["logging_params"]["name"],
    debug=False,
    create_git_tag=False,
)

In [12]:
# ----------------
# Checkpoint
# ----------------
checkpoint_callback = ModelCheckpoint(
    filepath="./checkpoints/lstm_reg_{epoch:02d}_{val_loss:.2f}",
    monitor="val_loss",
    verbose=True,
    save_top_k=5,
)

early_stopping = EarlyStopping(monitor="val_loss", patience=5, verbose=True)

EarlyStopping mode auto is unknown, fallback to auto mode.
EarlyStopping mode set to min for monitoring val_loss.


In [13]:
# ----------------
# Trainer
# ----------------

runner = Trainer(
    default_save_path=f"{tt_logger.save_dir}",
    min_epochs=1,
    logger=tt_logger,
    log_save_interval=100,
    train_percent_check=1.0,
    val_percent_check=1.0,
    num_sanity_val_steps=5,
    early_stop_callback=early_stopping,
    checkpoint_callback=checkpoint_callback,
    **config["trainer_params"],
)

GPU available: True, used: True
No environment variable for node rank defined. Set as 0.
CUDA_VISIBLE_DEVICES: [0]


In [14]:
# ----------------
# Start Train
# ----------------
runner.fit(experiment, train_loader, dev_loader)


  | Name          | Type             | Params
-----------------------------------------------
0 | model         | BiLSTMAttn       | 6 M   
1 | model.embed   | Embedding        | 3 M   
2 | model.bilstm  | LSTM             | 2 M   
3 | model.linear  | Linear           | 2 K   
4 | model.dropout | Dropout          | 0     
5 | _loss         | CrossEntropyLoss | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00000: val_loss reached 1.29489 (best 1.29489), saving model to ./checkpoints/lstm_reg_epoch=00_val_loss=1.29.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00001: val_loss reached 1.15086 (best 1.15086), saving model to ./checkpoints/lstm_reg_epoch=01_val_loss=1.15.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00002: val_loss reached 1.03194 (best 1.03194), saving model to ./checkpoints/lstm_reg_epoch=02_val_loss=1.03.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00003: val_loss reached 0.89760 (best 0.89760), saving model to ./checkpoints/lstm_reg_epoch=03_val_loss=0.90.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00004: val_loss reached 0.78696 (best 0.78696), saving model to ./checkpoints/lstm_reg_epoch=04_val_loss=0.79.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00005: val_loss reached 0.71318 (best 0.71318), saving model to ./checkpoints/lstm_reg_epoch=05_val_loss=0.71.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00006: val_loss reached 0.64143 (best 0.64143), saving model to ./checkpoints/lstm_reg_epoch=06_val_loss=0.64.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00007: val_loss reached 0.56242 (best 0.56242), saving model to ./checkpoints/lstm_reg_epoch=07_val_loss=0.56.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00008: val_loss reached 0.47350 (best 0.47350), saving model to ./checkpoints/lstm_reg_epoch=08_val_loss=0.47.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00009: val_loss reached 0.37095 (best 0.37095), saving model to ./checkpoints/lstm_reg_epoch=09_val_loss=0.37.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00010: val_loss reached 0.30097 (best 0.30097), saving model to ./checkpoints/lstm_reg_epoch=10_val_loss=0.30.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00011: val_loss reached 0.26754 (best 0.26754), saving model to ./checkpoints/lstm_reg_epoch=11_val_loss=0.27.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00012: val_loss reached 0.22568 (best 0.22568), saving model to ./checkpoints/lstm_reg_epoch=12_val_loss=0.23.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00013: val_loss reached 0.20668 (best 0.20668), saving model to ./checkpoints/lstm_reg_epoch=13_val_loss=0.21.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00014: val_loss reached 0.16296 (best 0.16296), saving model to ./checkpoints/lstm_reg_epoch=14_val_loss=0.16.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00015: val_loss reached 0.21529 (best 0.16296), saving model to ./checkpoints/lstm_reg_epoch=15_val_loss=0.22.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00016: val_loss reached 0.12208 (best 0.12208), saving model to ./checkpoints/lstm_reg_epoch=16_val_loss=0.12.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00017: val_loss reached 0.10982 (best 0.10982), saving model to ./checkpoints/lstm_reg_epoch=17_val_loss=0.11.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00018: val_loss reached 0.10229 (best 0.10229), saving model to ./checkpoints/lstm_reg_epoch=18_val_loss=0.10.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00019: val_loss reached 0.13028 (best 0.10229), saving model to ./checkpoints/lstm_reg_epoch=19_val_loss=0.13.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00020: val_loss reached 0.08218 (best 0.08218), saving model to ./checkpoints/lstm_reg_epoch=20_val_loss=0.08.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00021: val_loss reached 0.08498 (best 0.08218), saving model to ./checkpoints/lstm_reg_epoch=21_val_loss=0.08.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00022: val_loss reached 0.07316 (best 0.07316), saving model to ./checkpoints/lstm_reg_epoch=22_val_loss=0.07.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00023: val_loss reached 0.06691 (best 0.06691), saving model to ./checkpoints/lstm_reg_epoch=23_val_loss=0.07.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00024: val_loss reached 0.06709 (best 0.06691), saving model to ./checkpoints/lstm_reg_epoch=24_val_loss=0.07.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00025: val_loss reached 0.06048 (best 0.06048), saving model to ./checkpoints/lstm_reg_epoch=25_val_loss=0.06.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00026: val_loss reached 0.05401 (best 0.05401), saving model to ./checkpoints/lstm_reg_epoch=26_val_loss=0.05.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00027: val_loss reached 0.05365 (best 0.05365), saving model to ./checkpoints/lstm_reg_epoch=27_val_loss=0.05.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00028: val_loss reached 0.05551 (best 0.05365), saving model to ./checkpoints/lstm_reg_epoch=28_val_loss=0.06.ckpt as top 5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00029: val_loss reached 0.05032 (best 0.05032), saving model to ./checkpoints/lstm_reg_epoch=29_val_loss=0.05.ckpt as top 5





1