# Train BiLSTM + Attn Model

## imports

In [1]:
%load_ext lab_black

In [2]:
import sys

sys.path.append("..")

In [44]:
import pickle
from functools import partial
from collections import OrderedDict, defaultdict

import yaml
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl

from experiment import Experiment
from models import BiLSTMAttn
from utils import NewsDataset, collate_fn
from utils.types_ import *

from tqdm import tqdm

In [4]:
import warnings

warnings.filterwarnings("ignore")

In [5]:
# Device configuration
GPU_NUM = 1
DEVICE = torch.device(f"cuda:{GPU_NUM}" if torch.cuda.is_available() else "cpu")

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## 01. data load

In [6]:
config_path = "./config.yaml"
with open(config_path, "r") as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        print(exc)

In [7]:
# ----------------
# DataLoader
# ----------------
data_path = config["exp_params"]["data_path"]
vocab_path = config["exp_params"]["vocab_path"]
labels_list = ["조선일보", "동아일보", "경향신문", "한겨레"]
labels_dict = {label: idx for idx, label in enumerate(labels_list)}

with open(vocab_path, "rb") as f:
    word_index = pickle.load(f)


dataset = NewsDataset(data_path)

test_loader = DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=False,
    collate_fn=partial(collate_fn, word_index=word_index, labels_dict=labels_dict),
)

## 02. Model load

In [8]:
ckpt_path = "../checkpoints/BilstmAttn_epoch=29_val_loss=0.05.ckpt"

checkpoint = torch.load(ckpt_path)
checkpoint["state_dict"] = OrderedDict(
    [(key.replace("model.", ""), val) for key, val in checkpoint["state_dict"].items()]
)

In [9]:
# vocab_size & num_class
config["model_params"]["vocab_size"] = len(word_index)
config["model_params"]["num_class"] = len(labels_list)

model = BiLSTMAttn(**config["model_params"]).to(DEVICE)
model.load_state_dict(checkpoint["state_dict"])
model.eval()

BiLSTMAttn(
  (embed): Embedding(30002, 128, padding_idx=0)
  (bilstm): LSTM(128, 256, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  (linear): Linear(in_features=512, out_features=4, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)

## 03. Test

### 1) 각 키워드-언론사 별 단어의 attention score 계산

In [22]:
top_k = 20
index_word = {idx: word for word, idx in word_index.items()}
index_label = {idx: label for label, idx in labels_dict.items()}
result_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

for batch in tqdm(test_loader):
    sequences, labels, keywords = batch
    sequences = sequences.to(DEVICE)

    labels = [index_label[label] for label in labels.tolist()]
    _, attn_scores = model(sequences)

    for keyword, label, attn_score, sequence in zip(
        keywords, labels, attn_scores, sequences
    ):

        topk_attns, topk_idxs = torch.topk(attn_score, top_k)
        topk_attns = topk_attns.tolist()
        topk_seq = sequence[topk_idxs].tolist()

        result = [
            (index_word[seq], score)
            for seq, score in zip(topk_seq, topk_attns)
            if seq > 1
        ]

        for word, score in result:
            if len(word) > 1:
                result_dict[keyword][label][word].append(score)

100%|██████████| 263/263 [00:10<00:00, 25.59it/s]


In [23]:
media_keyword_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
for keyword, media_dict in result_dict.items():
    for media, word_dict in media_dict.items():
        for word, vals in word_dict.items():
            media_keyword_dict[keyword][media][word] = np.sum(vals)

### 2) 각 키워드-언론사별 상관계수 구하기

In [100]:
from itertools import combinations

In [97]:
keywords = ["필리버스터", "탄핵", "드루킹", "남북회담", "조국"]
medias = ["조선일보", "동아일보", "경향신문", "한겨레"]
media_comb = list(combinations(medias, 2))
corr_dict = defaultdict(list)
for keyword in keywords:
    for m1, m2 in media_comb:
        m1_words = sorted(
            media_keyword_dict[keyword][m1].items(), key=lambda x: x[1], reverse=True
        )[:100]
        m2_words = sorted(
            media_keyword_dict[keyword][m2].items(), key=lambda x: x[1], reverse=True
        )[:100]
        df1 = pd.DataFrame(
            m1_words, columns=[f"{keyword}_{m1}_단어", f"{keyword}_{m1}_점수"]
        )
        df2 = pd.DataFrame(
            m2_words, columns=[f"{keyword}_{m2}_단어", f"{keyword}_{m2}_점수"]
        )
        df3 = df1.merge(
            df2,
            how="inner",
            left_on=f"{keyword}_{m1}_단어",
            right_on=f"{keyword}_{m2}_단어",
        )
        corr = np.corrcoef(
            df3[f"{keyword}_{m1}_점수"].values, df3[f"{keyword}_{m2}_점수"].values
        )[0, 1]
        corr_dict[keyword].append((f"{m1}-{m2}", corr))

In [105]:
corr_dict["조국"]

[('조선일보-동아일보', 0.7272041372627024),
 ('조선일보-경향신문', 0.46991132026884136),
 ('조선일보-한겨레', 0.34387486764558295),
 ('동아일보-경향신문', 0.5519848021533316),
 ('동아일보-한겨레', 0.4877613094702468),
 ('경향신문-한겨레', 0.7343701182189166)]

### 3) 각 키워드-언론사별 상관계수 가설검정

In [102]:
from scipy import stats

In [135]:
# N = 100
# pval_dict = defaultdict(list)
# rho_list = np.arange(0, 1, 0.01)
# for rho in rho_list:
#     for keyword, corrs in corr_dict.items():
#         for m1m2, corr in corrs:
#             t = (corr - rho) / np.sqrt((1 - corr ** 2) / (N - 2))
#             pval = stats.t.sf(t, N - 1) * 2
#             pval_dict[keyword].append((m1m2, rho, pval))

In [136]:
N = 100
pval_dict = defaultdict(lambda: defaultdict(list))
rho_list = np.arange(0, 1, 0.01)
for rho in rho_list:
    for keyword, corrs in corr_dict.items():
        for m1m2, corr in corrs:
            t = (corr - rho) / np.sqrt((1 - corr ** 2) / (N - 2))
            pval = stats.t.sf(t, N - 1) * 2
            pval_dict[keyword][m1m2].append((rho, pval))

In [183]:
pval_max_dict = defaultdict(list)
for keyword, mp_dict in pval_dict.items():
    for m1m2, pvals in mp_dict.items():
        pval_maxs = [(rho, pval) for rho, pval in pvals if pval <= 0.05]
        try:
            rho, p_max = max(pval_maxs, key=lambda x: x[1])
        except:
            rho, p_max = "x", "x"
        pval_max_dict[keyword].append((m1m2, rho, p_max))

In [184]:
pval_max_dict

defaultdict(list,
            {'필리버스터': [('조선일보-동아일보', 0.59, 0.04844332610540452),
              ('조선일보-경향신문', 0.35000000000000003, 0.04660600101154737),
              ('조선일보-한겨레', 0.67, 0.03659845970067276),
              ('동아일보-경향신문', 0.29, 0.046351449461411516),
              ('동아일보-한겨레', 0.44, 0.044139099181635486),
              ('경향신문-한겨레', 0.5700000000000001, 0.04358989977207472)],
             '탄핵': [('조선일보-동아일보', 0.49, 0.04289258511166552),
              ('조선일보-경향신문', 0.41000000000000003, 0.038930950932673056),
              ('조선일보-한겨레', 0.56, 0.04075798799515275),
              ('동아일보-경향신문', 0.25, 0.044476168506502864),
              ('동아일보-한겨레', 0.76, 0.04076509073796321),
              ('경향신문-한겨레', 0.85, 0.044404205792409515)],
             '드루킹': [('조선일보-동아일보', 0.46, 0.04780008651995186),
              ('조선일보-경향신문', 0.02, 0.04969068220697995),
              ('조선일보-한겨레', 0.38, 0.04202903889439298),
              ('동아일보-경향신문', 'x', 'x'),
              ('동아일보-한겨레', 0.44, 0.04