# BiLSTM + Attention Classifier

## imports

In [1]:
%load_ext lab_black

In [2]:
import sys

sys.path.append("..")

In [3]:
import pickle
import dill
import numpy as np
from functools import partial
from collections import Counter, defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from tqdm import tqdm
from models import BiLSTMAttn
from utils.types_ import *
from utils.data import NewsDataset, collate_fn

## 01. Data loading

In [4]:
# Load the tokenized news corpus (presumably noun-only tokens, per the filename).
# The path is relative to the kernel's working directory (normally this notebook's folder).
data_path = "../data/tokenized/nouns_total_data.txt"
dataset = NewsDataset(data_path)

In [5]:
# dataset[0]

### 2) DataLoader

In [6]:
# Class labels: four Korean newspaper names (Chosun Ilbo, Dong-A Ilbo,
# Kyunghyang Shinmun, Hankyoreh). A name's position in the list is its integer class id.
labels_list = ["조선일보", "동아일보", "경향신문", "한겨레"]
labels_dict = dict(zip(labels_list, range(len(labels_list))))

# Vocabulary lookup handed to collate_fn below. Its len() is used as vocab_size.
# Presumably it maps token -> integer id; confirm against utils.data.collate_fn.
# NOTE(review): pickle.load can execute arbitrary code, so only load trusted files.
with open("../data/vocab/word_index.pkl", "rb") as f:
    word_index = pickle.load(f)

In [7]:
# Bind the vocabulary and label mapping once so the DataLoader can call
# collate_fn(batch) directly. Batches of 64 are drawn in shuffled order.
collate = partial(collate_fn, word_index=word_index, labels_dict=labels_dict)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=collate)

In [8]:
# for batch in dataloader:
#     sequences, labels, keywords = batch
#     break

## BiLSTM + Attention Model

In [9]:
# Device configuration
# Device configuration.
# GPU_NUM is the preferred CUDA device index. If this machine has fewer GPUs,
# fall back to cuda:0 instead of failing with "invalid device ordinal" on .to().
# Without CUDA, run on CPU.
GPU_NUM = 1
if not torch.cuda.is_available():
    DEVICE = torch.device("cpu")
elif GPU_NUM < torch.cuda.device_count():
    DEVICE = torch.device(f"cuda:{GPU_NUM}")
else:
    DEVICE = torch.device("cuda:0")
DEVICE

device(type='cuda', index=1)

In [10]:
vocab_size = len(word_index)
# Derive the class count from the label list (ids are 0..len-1) rather than a
# hard-coded 4, so the classifier head always matches the labels from collate_fn.
num_class = len(labels_list)

model = BiLSTMAttn(vocab_size, num_class).to(DEVICE)

In [11]:
# Take a single batch to smoke-test the forward pass, loss and metric.
# next(iter(...)) raises immediately on an empty loader. The previous for/break
# form would instead silently reuse whatever `sequences`/`labels` were still in
# the kernel, or fail later with a NameError.
sequences, labels, keywords = next(iter(dataloader))

sequences = sequences.to(DEVICE)
labels = labels.to(DEVICE)

In [32]:
# Forward pass on the sample batch. The model output is indexed with [0]. It is
# presumably a tuple whose first element is the class logits, shape
# (batch, num_class). TODO: confirm against models.BiLSTMAttn.
preds = model(sequences)[0]

In [33]:
# Cross-entropy over raw class scores. reduction="mean" is the default: it
# averages the per-sample losses over the batch.
criterion = nn.CrossEntropyLoss(reduction="mean")

In [34]:
# Sanity-check loss on the sample batch. CrossEntropyLoss expects integer class
# ids in [0, num_class); these are presumably produced by collate_fn from labels_dict (confirm).
loss = criterion(preds, labels)

In [35]:
# (p == labels).float().sum() / labels.size(0)

In [37]:
def accuracy(preds, labels):
    """Return the fraction of samples whose arg-max prediction equals the label.

    Args:
        preds: Class scores with classes along dim 1, e.g. (batch, num_class).
            Logits and probabilities both work, because only the arg-max is used.
        labels: Integer class ids, expected shape (batch,).

    Returns:
        A 0-dim float tensor on the same device as ``preds``.
    """
    # Compare the arg-max predictions computed here. The previous version
    # compared an undefined `p` left over from a deleted cell, which fails on a
    # fresh kernel.
    predicted = preds.argmax(dim=1)
    corrects = (predicted == labels).float().sum()
    return corrects / labels.numel()

In [38]:
# Accuracy of the freshly constructed model on the single sample batch.
# No training step appears above, so this is a smoke test only.
accuracy(preds, labels)

tensor(0.4375, device='cuda:1')