# BiLSTM + Attention Classifier

## imports

In [3]:
%load_ext lab_black

In [4]:
import sys

sys.path.append("..")

In [172]:
import pickle
import dill
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from collections import Counter, defaultdict
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from utils.types_ import *

## 01. data load

In [6]:
# Load the whole corpus into memory from a pickle file.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this at data you produced or otherwise trust.
data_path = "../data/tokenized/nouns_total_data.txt"
with open(data_path, "rb") as f:
    data = pickle.load(f)

In [19]:
# keyword, press, category, title, content
# columns = ["keyword", "press", "category", "title", "content"]
# df = pd.DataFrame(data, columns=columns)

In [217]:
# Number of items in field 4 (the content field) of every record.
text_len = [len(record[4]) for record in data]

In [218]:
np.mean(text_len)

237.42394667936205

## 02. build vocab

In [39]:
def build_vocab(
    dataset: List[Tuple], save_dir: str, num_words: int = 30000
) -> Tuple[Dict, Dict]:
    """Build a frequency-ranked vocabulary and persist it to disk.

    Tokens are read from field 4 of every record in ``dataset``. The
    ``num_words`` most frequent tokens receive consecutive ids starting at 2;
    ids 0 and 1 are reserved for ``<PAD>`` and ``<UNK>``.

    Args:
        dataset: Records whose element at index 4 is an iterable of tokens.
        save_dir: Directory in which ``word_index.pkl`` is written. It is
            created if it does not exist yet.
        num_words: Maximum number of corpus tokens kept in the vocabulary.

    Returns:
        A ``(word_index, index_word)`` pair: token -> id and id -> token.
    """
    import os

    # Count record by record instead of materialising one giant token list.
    token_counts = Counter()
    for record in tqdm(dataset):
        token_counts.update(record[4])

    # Reserved ids come first so padding / unknown lookups are stable.
    word_index = {"<PAD>": 0, "<UNK>": 1}
    for word, _ in token_counts.most_common(num_words):
        # A corpus token that collides with a reserved token must not
        # overwrite its id (which would also drop id 0 or 1 from index_word).
        if word not in word_index:
            word_index[word] = len(word_index)

    index_word = {idx: word for word, idx in word_index.items()}

    os.makedirs(save_dir, exist_ok=True)
    with open(f"{save_dir}/word_index.pkl", "wb") as f:
        dill.dump(word_index, f)

    return word_index, index_word

In [40]:
# Build the vocabulary from the loaded corpus; build_vocab also writes
# word_index.pkl into save_dir.
save_dir = "../data/vocab"
word_index, index_word = build_vocab(data, save_dir)

100%|██████████| 16804/16804 [00:00<00:00, 412672.05it/s]


## 03. Create Dataset & DataLoader

### 1) Dataset

In [150]:
# class NewsDataset(Dataset):
#     def __init__(
#         self,
#         path: str,
#         word_index: Dict,
#         max_len: int = 256,
#         labels: List = ["조선일보", "동아일보", "경향신문", "한겨레"],
#     ):
#         self.word_index = word_index
#         self.max_len = max_len
#         self.label_dict = {label: idx for idx, label in enumerate(labels)}

#         with open(path, "rb") as f:
#             self.data = pickle.load(f)

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         keyword = self.data[idx][0]
#         label = self.data[idx][1]
#         label = self.label_dict[label]
#         text = self.data[idx][4]
#         sequence = self.text_to_sequence(text)
#         sequence = self.pad_sequence(sequence)
#         return sequence, label, keyword

#     def text_to_sequence(self, text):
#         sequence = [self.word_index.get(word, 1) for word in text]
#         return sequence

#     def pad_sequence(self, sequence):
#         if len(sequence) > self.max_len:
#             sequence = sequence[: self.max_len]
#         else:
#             sequence = [0 for _ in range(self.max_len - len(sequence))] + sequence
#         return sequence

In [166]:
class NewsDataset(Dataset):
    """Map-style dataset over a pickled sequence of records.

    Each item is returned as ``(text, label, keyword)``, taken from fields
    4, 1 and 0 of the underlying record respectively. No encoding or padding
    happens here.
    """

    def __init__(self, path: str):
        # NOTE(review): pickle.load runs arbitrary code from the file; only
        # load trusted data.
        with open(path, "rb") as handle:
            records = pickle.load(handle)
        self.data = records

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        return record[4], record[1], record[0]

In [167]:
# Wrap the pickled corpus in a Dataset so it can be fed to a DataLoader.
data_path = "../data/tokenized/nouns_total_data.txt"
dataset = NewsDataset(data_path)

In [169]:
# dataset[0]

In [152]:
# len(dataset[200][0])

### 2) DataLoader

In [253]:
def collate_fn(batch, word_index, labels_dict, max_len=200):
    """Turn a batch of ``(text, label, keyword)`` entries into model inputs.

    Every text is mapped to vocabulary ids (unknown tokens become 1), cut to
    its first ``max_len`` tokens when longer, and otherwise left-padded with
    0 up to ``max_len``. Labels are translated through ``labels_dict``.

    Returns:
        ``(sequences, labels, keywords)`` as plain Python lists.
    """
    texts = [item[0] for item in batch]
    raw_labels = [item[1] for item in batch]
    keywords = [item[2] for item in batch]

    sequences = []
    for text in texts:
        # Keep the head of over-long texts; shorter ones are used whole.
        tokens = text[:max_len] if len(text) > max_len else text
        token_ids = [word_index.get(token, 1) for token in tokens]
        # Padding goes in front, so real tokens sit at the end of the row.
        padding = [0] * (max_len - len(token_ids))
        sequences.append(padding + token_ids)

    labels = [labels_dict[raw_label] for raw_label in raw_labels]
    return sequences, labels, keywords

In [254]:
# Press names act as class labels; the list position is the class id.
labels_list = ["조선일보", "동아일보", "경향신문", "한겨레"]
labels_dict = {name: position for position, name in enumerate(labels_list)}

# The vocabulary and label mapping are bound into the collate function, so
# encoding and padding happen per batch.
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=partial(collate_fn, labels_dict=labels_dict, word_index=word_index),
)

In [255]:
# for batch in dataloader:
#     sequences, labels, keywords = batch
#     break

In [81]:
# keywords

## 04. BiLSTM + Attention Model

### 1) Attention Class

In [85]:
# class Attention(nn.Module):
#     def __init__(self, enc_hid_dim, dec_hid_dim):
#         super(Attention, self).__init__()

#         self.attn = nn.Linear(enc_hid_dim + dec_hid_dim, dec_hid_dim)
#         self.v = nn.Linear(dec_hid_dim, 1, bias=False)

#     def forward(self, hidden, encoder_outputs):
#         batch_size = encoder_outputs.size(0)
#         src_len = encoder_outputs.size(1)

#         hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)

#         score = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
#         attention = self.v(score).squeeze(2)

#         return F.softmax(attention, dim=1)

### 2) BiLSTM + Attention Class

In [256]:
class BiLSTMAttn(nn.Module):
    """(Bi)LSTM encoder with dot-product attention pooling and a linear head.

    The input ids are embedded, passed through dropout and an LSTM, and the
    LSTM outputs are pooled with attention whose query is the output at the
    last time step. The pooled vector is projected to ``num_class`` logits.

    Args:
        vocab_size: Number of rows in the embedding table (id 0 is padding).
        num_class: Number of output classes.
        embed_dim: Embedding size.
        hidden_dim: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions.
        dropout_p: Dropout probability for the embeddings and, when there is
            more than one layer, between LSTM layers.
    """

    def __init__(
        self,
        vocab_size: int,
        num_class: int,
        embed_dim: int = 128,
        hidden_dim: int = 256,
        num_layers: int = 2,
        bidirectional: bool = True,
        dropout_p: float = 0.3,
    ):
        super(BiLSTMAttn, self).__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_directs = 2 if bidirectional else 1

        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        nn.init.xavier_uniform_(self.embed.weight)
        # The Xavier init above also overwrites the padding row; restore it to
        # zeros so padded positions contribute a zero embedding.
        with torch.no_grad():
            self.embed.weight[0].fill_(0)
        self.bilstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            # Inter-layer dropout is only defined for stacked LSTMs; passing a
            # non-zero value with a single layer just triggers a warning.
            dropout=dropout_p if num_layers > 1 else 0.0,
        )
        # The LSTM output width depends on the number of directions.
        self.linear = nn.Linear(hidden_dim * self.num_directs, num_class)
        self.dropout = nn.Dropout(dropout_p)

    def attention_layer(self, query, key):
        """Pool ``key`` over time with dot-product attention.

        Args:
            query: ``(batch, features)`` tensor; a leading singleton axis,
                i.e. ``(1, batch, features)``, is also accepted.
            key: ``(batch, seq_len, features)`` tensor.

        Returns:
            ``(contexts, attn_weights)`` with shapes ``(batch, features)``
            and ``(batch, seq_len)``.
        """
        # Only strip an explicit leading axis. An unconditional squeeze(0)
        # would collapse the batch axis whenever the batch size is 1.
        if query.dim() == 3:
            query = query.squeeze(0)
        attn_scores = torch.bmm(key, query.unsqueeze(2)).squeeze(2)
        attn_weights = F.softmax(attn_scores, dim=1)
        contexts = torch.bmm(key.transpose(1, 2), attn_weights.unsqueeze(2)).squeeze(2)
        return contexts, attn_weights

    def forward(self, sequence):
        """Return ``(logits, attn_weights)`` for a ``(batch, seq_len)`` id tensor."""
        x = self.embed(sequence)
        x = self.dropout(x)
        output, _ = self.bilstm(x)
        # NOTE(review): padded positions are not masked out of the attention,
        # and the backward half of the query has only seen the last token.
        output, attn_weights = self.attention_layer(output[:, -1, :], output)
        output = self.linear(output)
        return output, attn_weights

In [257]:
# Device configuration
# GPU_NUM is the preferred CUDA device index. If that index does not exist on
# this machine, fall back to GPU 0 rather than failing later when tensors are
# moved to the device; without CUDA, run on the CPU.
GPU_NUM = 1
if torch.cuda.is_available():
    DEVICE = torch.device(
        f"cuda:{GPU_NUM if GPU_NUM < torch.cuda.device_count() else 0}"
    )
else:
    DEVICE = torch.device("cpu")
DEVICE

device(type='cuda', index=1)

In [258]:
# Embedding table size, including the reserved <PAD> and <UNK> entries.
vocab_size = len(word_index)
# Derive the class count from the label list so the two cannot drift apart.
num_class = len(labels_list)

In [259]:
# Instantiate the classifier with its default hyperparameters and move it to
# the selected device.
model = BiLSTMAttn(vocab_size, num_class).to(DEVICE)

In [260]:
# Pull a single batch out of the loader for a quick forward-pass check.
for batch in dataloader:
    sequences, labels, keywords = batch
    break

# collate_fn returns plain Python lists, so convert the ids to a tensor on the
# model's device.
sequences = torch.LongTensor(sequences).to(DEVICE)

In [261]:
model(sequences)

(tensor([[0.0396, 0.0111, 0.0128, 0.0278],
         [0.0387, 0.0108, 0.0126, 0.0280]], device='cuda:1',
        grad_fn=<AddmmBackward>),
 tensor([[0.0045, 0.0047, 0.0048, 0.0049, 0.0049, 0.0050, 0.0050, 0.0050, 0.0050,
          0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050,
          0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050,
          0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050,
          0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050,
          0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050,
          0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050,
          0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050,
          0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050,
          0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050, 0.0050,
          0.0050, 0.0050, 0.0050, 0.0050, 

In [None]:
model.