# Train BiLSTM + Attn Model

## imports

In [1]:
%load_ext lab_black

In [2]:
import sys

sys.path.append("..")

In [88]:
import pickle
from functools import partial
from collections import OrderedDict, defaultdict

import yaml
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl

from experiment import Experiment
from models import BiLSTMAttn
from utils import NewsDataset, collate_fn
from utils.types_ import *

from tqdm import tqdm

In [11]:
import warnings

warnings.filterwarnings("ignore")

In [6]:
# Device configuration
GPU_NUM = 1
DEVICE = torch.device(f"cuda:{GPU_NUM}" if torch.cuda.is_available() else "cpu")

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## data load

In [7]:
config_path = "./config.yaml"
with open(config_path, "r") as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        print(exc)

In [8]:
# ----------------
# DataLoader
# ----------------
data_path = config["exp_params"]["data_path"]
vocab_path = config["exp_params"]["vocab_path"]
labels_list = ["조선일보", "동아일보", "경향신문", "한겨레"]
labels_dict = {label: idx for idx, label in enumerate(labels_list)}

with open(vocab_path, "rb") as f:
    word_index = pickle.load(f)


dataset = NewsDataset(data_path)

test_loader = DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=False,
    collate_fn=partial(collate_fn, word_index=word_index, labels_dict=labels_dict),
)

## Model load

In [13]:
ckpt_path = "../checkpoints/BilstmAttn_epoch=29_val_loss=0.05.ckpt"

checkpoint = torch.load(ckpt_path)
checkpoint["state_dict"] = OrderedDict(
    [(key.replace("model.", ""), val) for key, val in checkpoint["state_dict"].items()]
)

In [17]:
# vocab_size & num_class
config["model_params"]["vocab_size"] = len(word_index)
config["model_params"]["num_class"] = len(labels_list)

model = BiLSTMAttn(**config["model_params"]).to(DEVICE)
model.load_state_dict(checkpoint["state_dict"])
model.eval()

BiLSTMAttn(
  (embed): Embedding(30002, 128, padding_idx=0)
  (bilstm): LSTM(128, 256, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  (linear): Linear(in_features=512, out_features=4, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)

## Test

In [132]:
top_k = 20
index_word = {idx: word for word, idx in word_index.items()}
index_label = {idx: label for label, idx in labels_dict.items()}
result_dict = defaultdict(lambda: defaultdict(list))

In [142]:
# for batch in tqdm(test_loader):
#     sequences, labels, keywords = batch
#     sequences = sequences.to(DEVICE)

#     labels = [index_label[label] for label in labels.tolist()]
#     _, attn_scores = model(sequences)

#     for keyword, label, attn_score, sequence in zip(
#         keywords, labels, attn_scores, sequences
#     ):

#         topk_attns, topk_idxs = torch.topk(attn_score, top_k)
#         topk_attns = topk_attns.tolist()
#         topk_seq = sequence[topk_idxs].tolist()

#         result = [
#             (index_word[seq], score)
#             for seq, score in zip(topk_seq, topk_attns)
#             if seq > 1
#         ]
#         result_dict[keyword][label].extend(result)

In [241]:
top_k = 20
index_word = {idx: word for word, idx in word_index.items()}
index_label = {idx: label for label, idx in labels_dict.items()}
result_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

for batch in tqdm(test_loader):
    sequences, labels, keywords = batch
    sequences = sequences.to(DEVICE)

    labels = [index_label[label] for label in labels.tolist()]
    _, attn_scores = model(sequences)

    for keyword, label, attn_score, sequence in zip(
        keywords, labels, attn_scores, sequences
    ):

        topk_attns, topk_idxs = torch.topk(attn_score, top_k)
        topk_attns = topk_attns.tolist()
        topk_seq = sequence[topk_idxs].tolist()

        result = [
            (index_word[seq], score)
            for seq, score in zip(topk_seq, topk_attns)
            if seq > 1
        ]

        for word, score in result:
            if len(word) > 1:
                result_dict[keyword][label][word].append(score)

100%|██████████| 263/263 [00:10<00:00, 25.40it/s]


In [245]:
media_keyword_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))

In [247]:
for keyword, media_dict in result_dict.items():
    for media, word_dict in media_dict.items():
        for word, vals in word_dict.items():
            media_keyword_dict[keyword][media][word] = np.mean(vals)

In [260]:
tmp = media_keyword_dict["남북회담"]["한겨레"]

In [261]:
sorted_words = sorted(tmp, key=lambda x: x[1], reverse=True)

In [262]:
sorted_words[:400]

['김희중',
 '원희룡',
 '문희상',
 '정희영',
 '정희곤',
 '즉흥',
 '부흥',
 '진흥',
 '호흡',
 '사흘',
 '열흘',
 '나흘',
 '마흔',
 '연휴',
 '발휘',
 '지휘관',
 '지휘',
 '장휘국',
 '폄훼',
 '서훈',
 '관훈클럽',
 '교훈',
 '유훈',
 '직후',
 '이후',
 '막후',
 '추후',
 '향후',
 '선후',
 '기후',
 '사후',
 '윤후덕',
 '이후락',
 '노후',
 '조효제',
 '기획',
 '계획',
 '기획력',
 '기획사',
 '대회',
 '총회',
 '개회',
 '학회',
 '국회',
 '사회',
 '부회장',
 '보회',
 '선회',
 '사회관',
 '기회',
 '민회',
 '입회',
 '의회',
 '출회',
 '국회의원',
 '사회주의',
 '사회단체',
 '인회',
 '본회',
 '면회소',
 '순회',
 '교회',
 '철회',
 '본회의',
 '노회찬',
 '협회',
 '폐회식',
 '연회',
 '자회',
 '집회',
 '소회',
 '비회원국',
 '국회의사당',
 '상황실',
 '교황',
 '상황',
 '불황',
 '교황청',
 '현황',
 '생활관',
 '생활',
 '반환',
 '일환',
 '전환점',
 '교환',
 '순환',
 '귀환',
 '봉환',
 '송환',
 '질환',
 '소환',
 '소환장',
 '중환자실',
 '전환',
 '수확',
 '정확',
 '비확산',
 '강화',
 '평화',
 '변화',
 '전화',
 '완화',
 '활화산',
 '광화문',
 '대화',
 '평화통일',
 '민화협',
 '문화',
 '풍화',
 '평화상',
 '통화',
 '공화국',
 '문화유산',
 '문화재',
 '인화',
 '칭화',
 '문화원',
 '판화',
 '백화',
 '국화',
 '문화제',
 '약화',
 '영화제',
 '중화',
 '영화',
 '악화',
 '성화',
 '문화전',
 '헌화',
 '둔화',
 '문화방송',
 '영화관',
 '격화',
 '일화',
 '원화',