In [None]:
from src.dataset import MakeDataset
from src.model import MakeEmbed
from torch.utils.data import DataLoader

# Build the dataset helper.
dataset = MakeDataset()

# Create the embedding wrapper and load its word2vec model.
embed = MakeEmbed()
embed.load_word2vec()

# Train / test datasets for entity recognition, built with the embedding above.
entity_train_dataset, entity_test_dataset = dataset.make_entity_dataset(embed)

batch_size = 128

# Shuffle only the training data. Evaluation keeps a fixed order so that the
# composition of each batch (and therefore any per-batch metric) is
# reproducible from run to run.
train_dataloader = DataLoader(entity_train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(entity_test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
import torch
from src.model import BiLSTM_CRF

# Build the BiLSTM-CRF model, initialised from the word2vec embedding matrix.
weights = torch.FloatTensor(embed.word2vec.wv.vectors)

# NOTE(review): 256 and 128 are size hyperparameters whose meaning is defined
# by BiLSTM_CRF's constructor — confirm against src.model.
bilstm_crf_model = BiLSTM_CRF(weights, dataset.entity_label, 256, 128)
optimizer = torch.optim.Adam(bilstm_crf_model.parameters(), lr=1e-3)

# Put the model in training mode.
bilstm_crf_model.train()

In [None]:
from tqdm import tqdm
from tqdm import trange
import os
import torch.nn.functional as F

# Where checkpoints go and how their filenames start.
save_dir = "./nlp/pretrained/"
save_prefix = "cafe_entity_recog"

# Number of passes over the training data.
epoch = 15

# Best validation accuracy seen so far; a checkpoint is written only when beaten.
prev_acc = 0

def save(model, save_dir, save_prefix, epoch):
    """Write ``model``'s state dict to ``<save_dir>/<save_prefix>_steps_<epoch>.pt``.

    Args:
        model: Object exposing ``state_dict()``; its result is what gets saved.
        save_dir: Target directory. Created (including parents) if it does
            not exist yet.
        save_prefix: Leading part of the checkpoint filename.
        epoch: Value embedded in the filename after ``_steps_``.
    """
    # exist_ok=True replaces the isdir()-then-makedirs() pair, which could
    # raise if the directory appeared between the check and the creation.
    os.makedirs(save_dir, exist_ok=True)
    save_path = '{}_steps_{}.pt'.format(os.path.join(save_dir, save_prefix), epoch)
    torch.save(model.state_dict(), save_path)

for i in range(epoch):
    # ---------------- Training pass ----------------
    bilstm_crf_model.train()  # training mode: parameters are updated below

    with tqdm(train_dataloader, unit="batch") as tepoch:  # progress bar over batches
        tepoch.set_description(f"Epoch {i}")
        for data in tepoch:
            x = data[0]
            y = data[1]

            # backward() adds into the existing .grad buffers, so they must be
            # cleared before every step; otherwise each update also applies
            # the gradients of all previous batches.
            optimizer.zero_grad()

            # Call the module itself rather than .forward() so that any
            # registered hooks run as well.
            logits = bilstm_crf_model(x)

            # Positions where x > 0 are kept; everything else is treated as
            # padding and masked out of the loss.
            mask = (x > 0).type(torch.uint8)

            loss = bilstm_crf_model.compute_loss(y, logits, mask)

            loss.backward()
            optimizer.step()

            tepoch.set_postfix(loss=loss.item())

    # ---------------- Validation pass ----------------
    bilstm_crf_model.eval()  # evaluation mode: parameters are not updated
    n_correct = 0  # sequences whose whole tag sequence was predicted exactly
    n_total = 0    # sequences evaluated so far

    # No gradients are needed for validation; skipping graph construction
    # saves memory and time.
    with torch.no_grad(), tqdm(test_dataloader, unit="batch") as tepoch:
        tepoch.set_description(f"Epoch {i}")
        for data in tepoch:
            x = data[0]
            y = data[1]
            length = data[2]
            mask = (x > 0).type(torch.uint8)
            logits = bilstm_crf_model(x)

            predicts = bilstm_crf_model.decode(logits, mask)

            # A sequence counts as correct only if every tag up to its true
            # length matches the decoded path.
            for target, leng, predict in zip(y, length, predicts):
                n_correct += int(target[:leng].tolist() == predict)
                n_total += 1

            loss = bilstm_crf_model.compute_loss(y, logits, mask)

            # Running accuracy is weighted by sample count, so a smaller
            # final batch does not get the same weight as a full one.
            tepoch.set_postfix(loss=loss.item(), accuracy=100.0 * n_correct / max(n_total, 1))

    # Per-epoch validation accuracy; save a checkpoint only when it improves
    # on the best accuracy so far. An empty validation set yields 0.0 instead
    # of a ZeroDivisionError.
    acc = 100.0 * n_correct / n_total if n_total else 0.0
    if acc > prev_acc:
        prev_acc = acc
        save(bilstm_crf_model, save_dir, save_prefix + "_" + str(round(acc, 3)), i)