In [1]:
!nvidia-smi

Sun Oct 30 21:17:23 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.141.03   Driver Version: 470.141.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0 Off |                  N/A |
|  0%   29C    P8     9W / 350W |     17MiB / 24265MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

## 필요 패키지 다운로드

## 모듈 import

In [8]:
import json
import os

import torch
import torch.nn as nn
from tqdm import trange
from transformers import XLMRobertaModel, AutoTokenizer
from transformers import ElectraModel, ElectraTokenizer
from torch.utils.data import DataLoader, TensorDataset
from transformers import get_linear_schedule_with_warmup
from transformers import AdamW
from datasets import load_metric
from sklearn.metrics import f1_score
import pandas as pd
import copy

## 전역 변수 설정
구글 드라이브 마운트 기준으로 설정되어 있음

In [36]:
PADDING_TOKEN = 1
S_OPEN_TOKEN = 0
S_CLOSE_TOKEN = 2

do_eval=True

category_extraction_model_path = './saved_model/category_extraction_emb/' 
polarity_classification_model_path = './saved_model/polarity_classification_emb/'

test_category_extraction_model_path = './saved_model/category_extraction_emb/saved_model_epoch_15.pt'
test_polarity_classification_model_path = './saved_model/polarity_classification_emb/saved_model_epoch_15.pt'

train_data_path = './data/nikluge-sa-2022-train.jsonl'
dev_data_path = './data/nikluge-sa-2022-dev.jsonl'
test_data_path = './data/nikluge-sa-2022-test.jsonl'

max_len = 256
batch_size = 8
base_model = 'monologg/koelectra-base-v3-discriminator'
learning_rate = 3e-6
eps = 1e-8
num_train_epochs = 20
classifier_hidden_size = 768
emb_classifier_hidden_size = 1280 # 기존 768에서 -> 1024 임베딩을 concat한 vector
classifier_dropout_prob = 0.1
lstm_hidden = 256
lstm_num_layer = 1
hidden_dropout_prob = 0.3
bilstm_flag = True

entity_property_pair = [
    '제품 전체#일반', '제품 전체#가격', '제품 전체#디자인', '제품 전체#품질', '제품 전체#편의성', '제품 전체#인지도', '제품 전체#다양성',
    '본품#일반', '본품#디자인', '본품#품질', '본품#편의성', '본품#다양성', '본품#가격', '본품#인지도',
    '패키지/구성품#일반', '패키지/구성품#디자인', '패키지/구성품#품질', '패키지/구성품#편의성', '패키지/구성품#다양성', '패키지/구성품#가격',
    '브랜드#일반', '브랜드#가격', '브랜드#디자인', '브랜드#품질', '브랜드#인지도' ]

entity_property_to_id = {entity_property_pair[i] : i for i in range(len(entity_property_pair))}
print(entity_property_to_id)

tf_id_to_name = ['True', 'False']
tf_name_to_id = {tf_id_to_name[i]: i for i in range(len(tf_id_to_name))}

polarity_id_to_name = ['positive', 'negative', 'neutral']
polarity_name_to_id = {polarity_id_to_name[i]: i for i in range(len(polarity_id_to_name))}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

special_tokens_dict = {
    'additional_special_tokens': ['&name&', '&affiliation&', '&social-security-num&', '&tel-num&', '&card-num&', '&bank-account&', '&num&', '&online-account&']
}

entity_property_seq = [ i for i in range(len(entity_property_pair))]

{'제품 전체#일반': 0, '제품 전체#가격': 1, '제품 전체#디자인': 2, '제품 전체#품질': 3, '제품 전체#편의성': 4, '제품 전체#인지도': 5, '제품 전체#다양성': 6, '본품#일반': 7, '본품#디자인': 8, '본품#품질': 9, '본품#편의성': 10, '본품#다양성': 11, '본품#가격': 12, '본품#인지도': 13, '패키지/구성품#일반': 14, '패키지/구성품#디자인': 15, '패키지/구성품#품질': 16, '패키지/구성품#편의성': 17, '패키지/구성품#다양성': 18, '패키지/구성품#가격': 19, '브랜드#일반': 20, '브랜드#가격': 21, '브랜드#디자인': 22, '브랜드#품질': 23, '브랜드#인지도': 24}


## Json 파일 읽어오는 함수

In [10]:
def jsonload(fname, encoding="utf-8"):
    with open(fname, encoding=encoding) as f:
        j = json.load(f)

    return j


# json 개체를 파일이름으로 깔끔하게 저장
def jsondump(j, fname):
    with open(fname, "w", encoding="UTF8") as f:
        json.dump(j, f, ensure_ascii=False)

# jsonl 파일 읽어서 list에 저장
def jsonlload(fname, encoding="utf-8"):
    json_list = []
    with open(fname, encoding=encoding) as f:
        for line in f.readlines():
            json_list.append(json.loads(line))
    return json_list

jsonlload('./data/sample.jsonl')

[{'id': 'nikluge-sa-2022-train-00001',
  'sentence_form': '둘쨋날은 미친듯이 밟아봤더니 기어가 헛돌면서 틱틱 소리가 나서 경악.',
  'annotation': [['본품#품질', ['기어', 16, 18], 'negative']]},
 {'id': 'nikluge-sa-2022-train-00002',
  'sentence_form': '이거 뭐 삐꾸를 준 거 아냐 불안하고, 거금 투자한 게 왜 이래.. 싶어서 정이 확 떨어졌는데 산 곳 가져가서 확인하니 기어 텐션 문제라고 고장 아니래.',
  'annotation': [['본품#품질', ['기어 텐션', 67, 72], 'negative']]},
 {'id': 'nikluge-sa-2022-train-00003',
  'sentence_form': '간사하게도 그 이후에는 라이딩이 아주 즐거워져서 만족스럽게 탔다.',
  'annotation': [['제품 전체#일반', [None, 0, 0], 'positive']]},
 {'id': 'nikluge-sa-2022-train-00004',
  'sentence_form': '샥이 없는 모델이라 일반 도로에서 타면 노면의 진동 때문에 손목이 덜덜덜 떨리고 이가 부딪칠 지경인데 이마저도 며칠 타면서 익숙해지니 신경쓰이지 않게 됐다.',
  'annotation': [['제품 전체#일반', ['샥이 없는 모델', 0, 8], 'neutral']]},
 {'id': 'nikluge-sa-2022-train-00005',
  'sentence_form': '안장도 딱딱해서 엉덩이가 아팠는데 무시하고 타고 있다.',
  'annotation': [['본품#일반', ['안장', 0, 2], 'negative']]},
 {'id': 'nikluge-sa-2022-train-00006',
  'sentence_form': '지금 내 실력과 저질 체력으로는 이 정도 자전거도 되게 훌륭한 거라는..',
  'annotation'

## 모델 정의
xlm-roberta 모델을 기반으로 한 classification 모델 이용

In [23]:
class SimpleClassifier(nn.Module):

    def __init__(self, num_label):
        super().__init__()
        self.dense = nn.Linear(classifier_hidden_size, classifier_hidden_size)
        self.dropout = nn.Dropout(classifier_dropout_prob)
        self.output = nn.Linear(classifier_hidden_size, num_label)

    def forward(self, features):
        x = features[:, 0, :]
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.output(x)
        return x

class EmbeddingClassifier(nn.Module):

    def __init__(self, num_label):
        super().__init__()
        self.dense = nn.Linear(emb_classifier_hidden_size, classifier_hidden_size)
        self.dropout = nn.Dropout(classifier_dropout_prob)
        self.output = nn.Linear(classifier_hidden_size, num_label)

    def forward(self, features):
        x = features[:, 0, :]
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.output(x)
        return x

class CategoryClassification(nn.Module):
    def __init__(self, num_label, len_tokenizer, category_emb_size, category_size, num_layer, bilstm_flag):
        super().__init__()

        assert category_emb_size == lstm_hidden * 2, "Please set category-embedding-size to twice the lstm-hidden-size"

        self.num_label = num_label # 0, 1 -> 0이면 해당 카테고리는 있는 거임.
        self.electra = ElectraModel.from_pretrained(base_model)
        self.electra.resize_token_embeddings(len_tokenizer)

        self.n_hidden = lstm_hidden

        self.category_emb = nn.Embedding(category_size, category_emb_size, scale_grad_by_freq=True)

        self.num_layers = num_layer
        self.bidirectional = 2 if bilstm_flag else 1

        self.category_lstm_first = nn.LSTM(768, 256, bidirectional=True, batch_first=True)
        self.category_lstm_last = nn.LSTM(lstm_hidden * 4, self.n_hidden, num_layers=self.num_layers, batch_first=True, bidirectional=bilstm_flag)

        self.category_q_liner = nn.Linear(self.n_hidden * 2, self.n_hidden * 2)
        self.category_k_liner = nn.Linear(self.n_hidden * 2, self.n_hidden * 2)
        self.category_v_liner = nn.Linear(self.n_hidden * 2, self.n_hidden * 2)

        self.dropout = nn.Dropout(hidden_dropout_prob)

        self.softmax = nn.Softmax(dim=-1)

        self.labels_classifier = EmbeddingClassifier(self.num_label) # 입력 사이즈는 classifier_hidden_size = (768, 2) 0과 1 -> 


    def forward(self, input_ids=None, attention_mask=None, labels=None, category_label_seq_tensor=None, token_type_ids=None ):

        outputs = self.electra(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=None
        )

        # (batch_size, max_length, hidden_size)
        sequence_output = outputs[0] # (batch_size, 256, 768)
        # print("sequence_output size:", sequence_output.shape)

        category_embs = self.category_emb(category_label_seq_tensor)
        # print(category_label_seq_tensor)
        # print('category_embs size:', category_embs.shape)

        hidden = None
        scaler = self.n_hidden ** 0.5

        """
        category predict layer
        """
        category_lstm_outputs, hidden = self.category_lstm_first(sequence_output, hidden)
        category_lstm_outputs = self.dropout(category_lstm_outputs)

        # print('category_lstm_outputs size:',category_lstm_outputs.shape)

        category_q = self.category_q_liner(category_lstm_outputs)
        category_k = self.category_k_liner(category_embs)
        category_v = self.category_v_liner(category_embs)

        # print("category_q", category_q.shape)
        # print("category_k", category_k.shape)
        # print("category_v", category_v.shape)

        category_attention_score = category_q.matmul(category_k.permute(0, 2, 1)) / scaler
        category_attention_align = self.softmax(category_attention_score)
        # print("category_attention_score", category_attention_score)

        category_attention_output = category_attention_align.matmul(category_v)
        category_attention_output = self.dropout(category_attention_output)
        # print("category_attention_output", category_attention_output.shape)


        category_lstm_outputs = torch.cat([category_lstm_outputs, category_attention_output], dim=-1)
        # print("category_lstm_outputs:", category_lstm_outputs.shape)

        category_lstm_outputs, hidden = self.category_lstm_last(category_lstm_outputs, hidden)
        category_lstm_outputs = self.dropout(category_lstm_outputs) # 요걸 다음층에 입력에 넣어야 하나..?

        # print("category_lstm_outputs", category_lstm_outputs.shape)

        category_q = self.category_q_liner(category_lstm_outputs)
        category_k = self.category_k_liner(category_embs)

        # print("category_q_last", category_q.shape)
        # print("category_k_last", category_k.shape)

        category_attention_score = category_q.matmul(category_k.permute(0, 2, 1)) / scaler
        category_attention_score = self.dropout(category_attention_score)

        final_category_attention_score = category_attention_score[:, 0, :]

        category_attention_score = category_attention_score.matmul(category_embs) # [batch_size, max_length, max_length]
        category_attention_score = torch.cat([sequence_output, category_attention_score], dim=-1)

        # print(category_attention_score.shape)
      
        logits = self.labels_classifier(category_attention_score)

        loss = None

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_label),
                                                labels.view(-1))

        return loss, logits

class RoBertaBaseClassifier(nn.Module):
    def __init__(self, num_label, len_tokenizer):
        super(RoBertaBaseClassifier, self).__init__()

        self.num_label = num_label
        self.electra = ElectraModel.from_pretrained(base_model)
        self.electra.resize_token_embeddings(len_tokenizer)

        self.labels_classifier = SimpleClassifier(self.num_label)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.electra(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=None
        )

        sequence_output = outputs[0]
        # print(sequence_output.shape)
        logits = self.labels_classifier(sequence_output)
        # print(logits.shape)

        loss = None

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_label),
                                                labels.view(-1))

        return loss, logits

## 데이터 파싱 및 토크나이저 정의

In [24]:
def tokenize_and_align_labels(tokenizer, form, annotations, max_len):

    entity_property_data_dict = {
        'input_ids': [],
        'attention_mask': [],
        'label': [],
        'category': []
    }
    polarity_data_dict = {
        'input_ids': [],
        'attention_mask': [],
        'label': []
    }

    for pair in entity_property_pair:
        isPairInOpinion = False
        if pd.isna(form):
            break
        tokenized_data = tokenizer(form, pair, padding='max_length', max_length=max_len, truncation=True)
        for annotation in annotations:
            entity_property = annotation[0]
            polarity = annotation[2]

            if polarity == '------------':
                continue

            category_seq = [ i for i in range(len(entity_property_pair))]    

            if entity_property == pair:
                entity_property_data_dict['input_ids'].append(tokenized_data['input_ids'])
                entity_property_data_dict['attention_mask'].append(tokenized_data['attention_mask'])
                entity_property_data_dict['label'].append(tf_name_to_id['True'])
                entity_property_data_dict['category'].append(category_seq)
                # print(entity_property_data_dict)

                polarity_data_dict['input_ids'].append(tokenized_data['input_ids'])
                polarity_data_dict['attention_mask'].append(tokenized_data['attention_mask'])
                polarity_data_dict['label'].append(polarity_name_to_id[polarity])

                isPairInOpinion = True
                break

        if isPairInOpinion is False:
            entity_property_data_dict['input_ids'].append(tokenized_data['input_ids'])
            entity_property_data_dict['attention_mask'].append(tokenized_data['attention_mask'])
            entity_property_data_dict['label'].append(tf_name_to_id['False'])
            entity_property_data_dict['category'].append(category_seq) # 개체 없음(즉, 틀린 개체#속성 pair의 경우)

    return entity_property_data_dict, polarity_data_dict


def get_dataset(raw_data, tokenizer, max_len):
    input_ids_list = []
    attention_mask_list = []
    token_labels_list = []
    category_seq_list = []

    polarity_input_ids_list = []
    polarity_attention_mask_list = []
    polarity_token_labels_list = []

    for utterance in raw_data:
        entity_property_data_dict, polarity_data_dict = tokenize_and_align_labels(tokenizer, utterance['sentence_form'], utterance['annotation'], max_len)
        input_ids_list.extend(entity_property_data_dict['input_ids'])
        attention_mask_list.extend(entity_property_data_dict['attention_mask'])
        token_labels_list.extend(entity_property_data_dict['label'])
        category_seq_list.extend(entity_property_data_dict['category'])
       
        polarity_input_ids_list.extend(polarity_data_dict['input_ids'])
        polarity_attention_mask_list.extend(polarity_data_dict['attention_mask'])
        polarity_token_labels_list.extend(polarity_data_dict['label'])

    return TensorDataset(torch.tensor(input_ids_list), torch.tensor(attention_mask_list),
                         torch.tensor(token_labels_list), torch.tensor(category_seq_list)), TensorDataset(torch.tensor(polarity_input_ids_list), torch.tensor(polarity_attention_mask_list),
                         torch.tensor(polarity_token_labels_list))

## 모델 학습

In [25]:
def evaluation(y_true, y_pred, label_len):
    count_list = [0]*label_len
    hit_list = [0]*label_len
    for i in range(len(y_true)):
        count_list[y_true[i]] += 1
        if y_true[i] == y_pred[i]:
            hit_list[y_true[i]] += 1
    acc_list = []

    for i in range(label_len):
        acc_list.append(hit_list[i]/count_list[i])

    print(count_list)
    print(hit_list)
    print(acc_list)
    print('accuracy: ', (sum(hit_list) / sum(count_list)))
    print('macro_accuracy: ', sum(acc_list) / 3)
    # print(y_true)

    y_true = list(map(int, y_true))
    y_pred = list(map(int, y_pred))

    print('f1_score: ', f1_score(y_true, y_pred, average=None))
    print('f1_score_micro: ', f1_score(y_true, y_pred, average='micro'))
    print('f1_score_macro: ', f1_score(y_true, y_pred, average='macro'))

def train_sentiment_analysis():

    print('train_sentiment_analysis')
    print('category_extraction model would be saved at ', category_extraction_model_path)
    print('polarity model would be saved at ', polarity_classification_model_path)

    print('loading train data...')
    train_data = jsonlload(train_data_path)
    dev_data = jsonlload(dev_data_path)

    print('tokenizing train data...')
    tokenizer = ElectraTokenizer.from_pretrained(base_model)
    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    print('We have added', num_added_toks, 'tokens')

    print('making dataset...')
    entity_property_train_data, polarity_train_data = get_dataset(train_data, tokenizer, max_len)
    entity_property_dev_data, polarity_dev_data = get_dataset(dev_data, tokenizer, max_len)

    print('making dataloader...')
    entity_property_train_dataloader = DataLoader(entity_property_train_data, shuffle=True,
                                  batch_size=batch_size)
    entity_property_dev_dataloader = DataLoader(entity_property_dev_data, shuffle=True,
                                batch_size=batch_size)

    polarity_train_dataloader = DataLoader(polarity_train_data, shuffle=True,
                                                  batch_size=batch_size)
    polarity_dev_dataloader = DataLoader(polarity_dev_data, shuffle=True,
                                                batch_size=batch_size)

    print('loading model...')
    # entity_property_model = RoBertaBaseClassifier(len(tf_id_to_name), len(tokenizer))
    entity_property_model = CategoryClassification(len(tf_id_to_name), len(tokenizer), lstm_hidden * 2, len(entity_property_pair), lstm_num_layer, bilstm_flag)
    entity_property_model.to(device)

    polarity_model = RoBertaBaseClassifier(len(polarity_id_to_name), len(tokenizer))
    polarity_model.to(device)

    print('end loading')

    # entity_property_model_optimizer_setting
    FULL_FINETUNING = True
    if FULL_FINETUNING:
        entity_property_param_optimizer = list(entity_property_model.named_parameters())
        no_decay = ['bias', 'gamma', 'beta']
        entity_property_optimizer_grouped_parameters = [
            {'params': [p for n, p in entity_property_param_optimizer if not any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.01},
            {'params': [p for n, p in entity_property_param_optimizer if any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.0}
        ]
    else:
        entity_property_param_optimizer = list(entity_property_model.classifier.named_parameters())
        entity_property_optimizer_grouped_parameters = [{"params": [p for n, p in entity_property_param_optimizer]}]

    entity_property_optimizer = AdamW(
        entity_property_optimizer_grouped_parameters,
        lr=learning_rate,
        eps=eps
    )
    epochs = num_train_epochs
    max_grad_norm = 1.0
    total_steps = epochs * len(entity_property_train_dataloader)
    print("total_steps : ", total_steps)

    entity_property_scheduler = get_linear_schedule_with_warmup(
        entity_property_optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )

    # polarity_model_optimizer_setting
    if FULL_FINETUNING:
        polarity_param_optimizer = list(polarity_model.named_parameters())
        no_decay = ['bias', 'gamma', 'beta']
        polarity_optimizer_grouped_parameters = [
            {'params': [p for n, p in polarity_param_optimizer if not any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.01},
            {'params': [p for n, p in polarity_param_optimizer if any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.0}
        ]
    else:
        polarity_param_optimizer = list(polarity_model.classifier.named_parameters())
        polarity_optimizer_grouped_parameters = [{"params": [p for n, p in polarity_param_optimizer]}]

    polarity_optimizer = AdamW(
        polarity_optimizer_grouped_parameters,
        lr=learning_rate,
        eps=eps
    )
    epochs = num_train_epochs
    max_grad_norm = 1.0
    total_steps = epochs * len(polarity_train_dataloader)

    polarity_scheduler = get_linear_schedule_with_warmup(
        polarity_optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )


    epoch_step = 0
    print("학습을 시작합니다...")
    for _ in trange(epochs, desc="Epoch"):
        entity_property_model.train()
        epoch_step += 1

        # entity_property train
        entity_property_total_loss = 0

        for step, batch in enumerate(entity_property_train_dataloader):
            batch = tuple(t.to(device) for t in batch)
            b_input_ids, b_input_mask, b_labels, b_category_seq = batch
            # print("b_input_ids:", b_input_ids)
            # print("b_input_mask:", b_input_mask)
            # print("b_labels:", b_labels)
            # print("")

            entity_property_model.zero_grad()

            loss, _ = entity_property_model(b_input_ids, b_input_mask, b_labels, b_category_seq)

            loss.backward()

            entity_property_total_loss += loss.item()
            # print('batch_loss: ', loss.item())

            torch.nn.utils.clip_grad_norm_(parameters=entity_property_model.parameters(), max_norm=max_grad_norm)
            entity_property_optimizer.step()
            entity_property_scheduler.step()

        avg_train_loss = entity_property_total_loss / len(entity_property_train_dataloader)
        print("Entity_Property_Epoch: ", epoch_step)
        print("Average train loss: {}".format(avg_train_loss))

        model_saved_path = category_extraction_model_path + 'saved_model_epoch_' + str(epoch_step) + '.pt'
        torch.save(entity_property_model.state_dict(), model_saved_path)

        if do_eval:
            entity_property_model.eval()

            pred_list = []
            label_list = []

            for batch in entity_property_dev_dataloader:
                batch = tuple(t.to(device) for t in batch)
                b_input_ids, b_input_mask, b_labels, b_category_seq = batch

                with torch.no_grad():
                    loss, logits = entity_property_model(b_input_ids, b_input_mask, b_labels, b_category_seq)

                predictions = torch.argmax(logits, dim=-1)
                pred_list.extend(predictions)
                label_list.extend(b_labels)

            evaluation(label_list, pred_list, len(tf_id_to_name))


        # polarity train
        polarity_total_loss = 0
        polarity_model.train()

        for step, batch in enumerate(polarity_train_dataloader):
            batch = tuple(t.to(device) for t in batch)
            b_input_ids, b_input_mask, b_labels = batch

            polarity_model.zero_grad()

            loss, _ = polarity_model(b_input_ids, b_input_mask, b_labels)

            loss.backward()

            polarity_total_loss += loss.item()
            # print('batch_loss: ', loss.item())

            torch.nn.utils.clip_grad_norm_(parameters=polarity_model.parameters(), max_norm=max_grad_norm)
            polarity_optimizer.step()
            polarity_scheduler.step()

        avg_train_loss = polarity_total_loss / len(polarity_train_dataloader)
        print("Entity_Property_Epoch: ", epoch_step)
        print("Average train loss: {}".format(avg_train_loss))

        model_saved_path = polarity_classification_model_path + 'saved_model_epoch_' + str(epoch_step) + '.pt'
        torch.save(polarity_model.state_dict(), model_saved_path)

        if do_eval:
            polarity_model.eval()

            pred_list = []
            label_list = []

            for batch in polarity_dev_dataloader:
                batch = tuple(t.to(device) for t in batch)
                b_input_ids, b_input_mask, b_labels = batch

                with torch.no_grad():
                    loss, logits = polarity_model(b_input_ids, b_input_mask, b_labels)

                predictions = torch.argmax(logits, dim=-1)
                pred_list.extend(predictions)
                label_list.extend(b_labels)

            evaluation(label_list, pred_list, len(polarity_id_to_name))

    print("training is done")

In [26]:
train_sentiment_analysis()

train_sentiment_analysis
category_extraction model would be saved at  ./saved_model/category_extraction_emb/
polarity model would be saved at  ./saved_model/polarity_classification_emb/
loading train data...
tokenizing train data...
We have added 8 tokens
making dataset...
making dataloader...
loading model...


Some weights of the model checkpoint at monologg/koelectra-base-v3-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at monologg/koelectra-base-v3-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.weight', 'discr

end loading
total_steps :  187580
학습을 시작합니다...


Epoch:   0%|          | 0/20 [00:00<?, ?it/s]

Entity_Property_Epoch:  1
Average train loss: 0.14129454995150015
[3003, 66847]
[1536, 66147]
[0.5114885114885115, 0.9895283258784987]
accuracy:  0.9689763779527559
macro_accuracy:  0.5003389457890034
f1_score:  [0.58637144 0.9838838 ]
f1_score_micro:  0.9689763779527559
f1_score_macro:  0.7851276238352898
Entity_Property_Epoch:  1
Average train loss: 0.3314095962001011


Epoch:   5%|▌         | 1/20 [21:27<6:47:41, 1287.46s/it]

[2921, 28, 54]
[2921, 0, 0]
[1.0, 0.0, 0.0]
accuracy:  0.9726939726939727
macro_accuracy:  0.3333333333333333
f1_score:  [0.986158 0.       0.      ]
f1_score_micro:  0.9726939726939727
f1_score_macro:  0.3287193337834797
Entity_Property_Epoch:  2
Average train loss: 0.10841248687699695
[3003, 66847]
[1924, 65971]
[0.6406926406926406, 0.986895447813664]
accuracy:  0.9720114531138153
macro_accuracy:  0.542529362835435
f1_score:  [0.66310529 0.98539922]
f1_score_micro:  0.9720114531138153
f1_score_macro:  0.8242522575721528
Entity_Property_Epoch:  2
Average train loss: 0.1924014145741239


Epoch:  10%|█         | 2/20 [42:55<6:26:22, 1287.92s/it]

[2921, 28, 54]
[2921, 0, 0]
[1.0, 0.0, 0.0]
accuracy:  0.9726939726939727
macro_accuracy:  0.3333333333333333
f1_score:  [0.986158 0.       0.      ]
f1_score_micro:  0.9726939726939727
f1_score_macro:  0.3287193337834797
Entity_Property_Epoch:  3
Average train loss: 0.09367765388483991
[3003, 66847]
[1815, 66280]
[0.6043956043956044, 0.9915179439615839]
accuracy:  0.9748747315676449
macro_accuracy:  0.5319711827857294
f1_score:  [0.67409471 0.9869337 ]
f1_score_micro:  0.9748747315676449
f1_score_macro:  0.8305142040750048
Entity_Property_Epoch:  3
Average train loss: 0.1602341804234311


Epoch:  15%|█▌        | 3/20 [1:04:25<6:05:09, 1288.80s/it]

[2921, 28, 54]
[2915, 13, 0]
[0.9979459089352961, 0.4642857142857143, 0.0]
accuracy:  0.975024975024975
macro_accuracy:  0.48741054107367016
f1_score:  [0.98796814 0.50980392 0.        ]
f1_score_micro:  0.975024975024975
f1_score_macro:  0.4992573541872265
Entity_Property_Epoch:  4
Average train loss: 0.08384895147162572
[3003, 66847]
[1948, 66111]
[0.6486846486846487, 0.9889897826379643]
accuracy:  0.9743593414459556
macro_accuracy:  0.5458914771075377
f1_score:  [0.68507122 0.98663562]
f1_score_micro:  0.9743593414459556
f1_score_macro:  0.8358534199769672
Entity_Property_Epoch:  4
Average train loss: 0.14120191658148543


Epoch:  20%|██        | 4/20 [1:25:52<5:43:29, 1288.10s/it]

[2921, 28, 54]
[2904, 20, 0]
[0.9941800753166724, 0.7142857142857143, 0.0]
accuracy:  0.9736929736929737
macro_accuracy:  0.5694885965341289
f1_score:  [0.9877551  0.55555556 0.        ]
f1_score_micro:  0.9736929736929737
f1_score_macro:  0.5144368858654573
Entity_Property_Epoch:  5
Average train loss: 0.07254611953779325
[3003, 66847]
[1903, 66185]
[0.6336996336996337, 0.9900967881879515]
accuracy:  0.9747745168217609
macro_accuracy:  0.5412654739625284
f1_score:  [0.68354885 0.98686369]
f1_score_micro:  0.9747745168217609
f1_score_macro:  0.8352062685462356
Entity_Property_Epoch:  5
Average train loss: 0.13156062470981852


Epoch:  25%|██▌       | 5/20 [1:47:19<5:21:53, 1287.54s/it]

[2921, 28, 54]
[2905, 21, 0]
[0.9945224238274564, 0.75, 0.0]
accuracy:  0.9743589743589743
macro_accuracy:  0.5815074746091521
f1_score:  [0.98792722 0.5915493  0.        ]
f1_score_micro:  0.9743589743589743
f1_score_macro:  0.5264921730119992
Entity_Property_Epoch:  6
Average train loss: 0.061765991853791884
[3003, 66847]
[2013, 66007]
[0.6703296703296703, 0.9874339910541984]
accuracy:  0.9738010021474588
macro_accuracy:  0.5525878871279563
f1_score:  [0.6875     0.98632737]
f1_score_micro:  0.9738010021474588
f1_score_macro:  0.8369136830937509
Entity_Property_Epoch:  6
Average train loss: 0.12141880097682588


Epoch:  30%|███       | 6/20 [2:08:44<5:00:15, 1286.81s/it]

[2921, 28, 54]
[2900, 21, 0]
[0.9928106812735364, 0.75, 0.0]
accuracy:  0.9726939726939727
macro_accuracy:  0.5809368937578455
f1_score:  [0.98757024 0.53164557 0.        ]
f1_score_micro:  0.9726939726939727
f1_score_macro:  0.5064052687655228
Entity_Property_Epoch:  7
Average train loss: 0.05444903843851038
[3003, 66847]
[2039, 66019]
[0.678987678987679, 0.9876135054677099]
accuracy:  0.9743450250536865
macro_accuracy:  0.5555337281517962
f1_score:  [0.69471891 0.98660988]
f1_score_micro:  0.9743450250536865
f1_score_macro:  0.8406643939570415
Entity_Property_Epoch:  7
Average train loss: 0.119792598002241


Epoch:  35%|███▌      | 7/20 [2:30:11<4:38:49, 1286.85s/it]

[2921, 28, 54]
[2894, 21, 3]
[0.9907565902088326, 0.75, 0.05555555555555555]
accuracy:  0.9716949716949717
macro_accuracy:  0.5987707152547961
f1_score:  [0.98703956 0.53164557 0.0952381 ]
f1_score_micro:  0.9716949716949717
f1_score_macro:  0.5379744094320916
Entity_Property_Epoch:  8
Average train loss: 0.04682100139032272
[3003, 66847]
[1996, 66102]
[0.6646686646686647, 0.9888551468278307]
accuracy:  0.9749176807444524
macro_accuracy:  0.5511746038321651
f1_score:  [0.69498607 0.98692108]
f1_score_micro:  0.9749176807444524
f1_score_macro:  0.8409535754932543
Entity_Property_Epoch:  8
Average train loss: 0.10903113144042437


Epoch:  40%|████      | 8/20 [2:51:39<4:17:25, 1287.15s/it]

[2921, 28, 54]
[2890, 21, 6]
[0.9893871961656967, 0.75, 0.1111111111111111]
accuracy:  0.9713619713619713
macro_accuracy:  0.6168327690922694
f1_score:  [0.98752776 0.51219512 0.16901408]
f1_score_micro:  0.9713619713619713
f1_score_macro:  0.5562456566661089
Entity_Property_Epoch:  9
Average train loss: 0.04198846133252614
[3003, 66847]
[1912, 66170]
[0.6366966366966367, 0.9898723951710623]
accuracy:  0.974688618468146
macro_accuracy:  0.5421896772892331
f1_score:  [0.68383405 0.9868166 ]
f1_score_micro:  0.974688618468146
f1_score_macro:  0.835325322110299
Entity_Property_Epoch:  9
Average train loss: 0.09356571434269427


Epoch:  45%|████▌     | 9/20 [3:13:05<3:55:54, 1286.79s/it]

[2921, 28, 54]
[2904, 17, 2]
[0.9941800753166724, 0.6071428571428571, 0.037037037037037035]
accuracy:  0.9733599733599734
macro_accuracy:  0.5461199898321888
f1_score:  [0.98792312 0.64150943 0.05405405]
f1_score_micro:  0.9733599733599734
f1_score_macro:  0.5611622013975128
Entity_Property_Epoch:  10
Average train loss: 0.03656304037628426
[3003, 66847]
[1963, 66121]
[0.6536796536796536, 0.9891393779825571]
accuracy:  0.9747172512526843
macro_accuracy:  0.5476063438874036
f1_score:  [0.68973999 0.98682168]
f1_score_micro:  0.9747172512526843
f1_score_macro:  0.8382808341164054
Entity_Property_Epoch:  10
Average train loss: 0.08324819936446147


Epoch:  50%|█████     | 10/20 [3:34:33<3:34:32, 1287.22s/it]

[2921, 28, 54]
[2901, 17, 8]
[0.9931530297843204, 0.6071428571428571, 0.14814814814814814]
accuracy:  0.9743589743589743
macro_accuracy:  0.5828146783584419
f1_score:  [0.98824732 0.65384615 0.19277108]
f1_score_micro:  0.9743589743589743
f1_score_macro:  0.6116215185019785
Entity_Property_Epoch:  11
Average train loss: 0.030410986410172338
[3003, 66847]
[1994, 65996]
[0.6640026640026641, 0.9872694361751462]
accuracy:  0.9733715103793844
macro_accuracy:  0.5504240333926034
f1_score:  [0.68194254 0.98610406]
f1_score_micro:  0.9733715103793844
f1_score_macro:  0.8340232998424092
Entity_Property_Epoch:  11
Average train loss: 0.0732990476526902


Epoch:  55%|█████▌    | 11/20 [3:55:59<3:13:02, 1286.97s/it]

[2921, 28, 54]
[2899, 17, 9]
[0.9924683327627525, 0.6071428571428571, 0.16666666666666666]
accuracy:  0.974025974025974
macro_accuracy:  0.5887592855240921
f1_score:  [0.9882393  0.64150943 0.20930233]
f1_score_micro:  0.974025974025974
f1_score_macro:  0.6130170213762087
Entity_Property_Epoch:  12
Average train loss: 0.02808201900538645
[3003, 66847]
[2021, 65963]
[0.672993672993673, 0.9867757715379898]
accuracy:  0.9732856120257695
macro_accuracy:  0.5532564815105543
f1_score:  [0.68415708 0.98605298]
f1_score_micro:  0.9732856120257695
f1_score_macro:  0.8351050264544266
Entity_Property_Epoch:  12
Average train loss: 0.06868988945847378


Epoch:  60%|██████    | 12/20 [4:17:28<2:51:38, 1287.34s/it]

[2921, 28, 54]
[2859, 18, 18]
[0.9787743923313934, 0.6428571428571429, 0.3333333333333333]
accuracy:  0.964035964035964
macro_accuracy:  0.6516549561739565
f1_score:  [0.98314993 0.66666667 0.26470588]
f1_score_micro:  0.964035964035964
f1_score_macro:  0.6381741600812723
Entity_Property_Epoch:  13
Average train loss: 0.02327567319471759
[3003, 66847]
[1981, 66002]
[0.6596736596736597, 0.9873591933819019]
accuracy:  0.9732712956335003
macro_accuracy:  0.5490109510185205
f1_score:  [0.67970492 0.98605374]
f1_score_micro:  0.9732712956335003
f1_score_macro:  0.8328793309789391
Entity_Property_Epoch:  13
Average train loss: 0.05718143567602965


Epoch:  65%|██████▌   | 13/20 [4:38:57<2:30:15, 1287.94s/it]

[2921, 28, 54]
[2862, 16, 16]
[0.9798014378637453, 0.5714285714285714, 0.2962962962962963]
accuracy:  0.9637029637029637
macro_accuracy:  0.6158421018628709
f1_score:  [0.98299845 0.65306122 0.23880597]
f1_score_micro:  0.9637029637029637
f1_score_macro:  0.6249552163479984
Entity_Property_Epoch:  14
Average train loss: 0.021142978826761888
[3003, 66847]
[2043, 65926]
[0.6803196803196803, 0.9862222687629961]
accuracy:  0.9730708661417323
macro_accuracy:  0.5555139830275588
f1_score:  [0.68476621 0.98593466]
f1_score_micro:  0.9730708661417323
f1_score_macro:  0.8353504375160343
Entity_Property_Epoch:  14
Average train loss: 0.04661028659698786


Epoch:  70%|███████   | 14/20 [5:00:24<2:08:46, 1287.72s/it]

[2921, 28, 54]
[2896, 14, 9]
[0.9914412872304006, 0.5, 0.16666666666666666]
accuracy:  0.972027972027972
macro_accuracy:  0.5527026512990224
f1_score:  [0.98721664 0.62222222 0.19148936]
f1_score_micro:  0.972027972027972
f1_score_macro:  0.6003094064475973
Entity_Property_Epoch:  15
Average train loss: 0.017704774665180695
[3003, 66847]
[2005, 66023]
[0.6676656676656677, 0.987673343605547]
accuracy:  0.973915533285612
macro_accuracy:  0.5517796704237382
f1_score:  [0.68758573 0.98638958]
f1_score_micro:  0.973915533285612
f1_score_macro:  0.8369876558375401
Entity_Property_Epoch:  15
Average train loss: 0.046919974527409064


Epoch:  75%|███████▌  | 15/20 [5:21:50<1:47:16, 1287.33s/it]

[2921, 28, 54]
[2890, 15, 13]
[0.9893871961656967, 0.5357142857142857, 0.24074074074074073]
accuracy:  0.9716949716949717
macro_accuracy:  0.5886140742069077
f1_score:  [0.98735907 0.6122449  0.25242718]
f1_score_micro:  0.9716949716949717
f1_score_macro:  0.6173437177153591
Entity_Property_Epoch:  16
Average train loss: 0.016516461799294505
[3003, 66847]
[1954, 66110]
[0.6506826506826506, 0.988974823103505]
accuracy:  0.9744309234073013
macro_accuracy:  0.546552491262052
f1_score:  [0.68633649 0.98667224]
f1_score_micro:  0.9744309234073013
f1_score_macro:  0.8365043665560776
Entity_Property_Epoch:  16
Average train loss: 0.044081689404847565


Epoch:  80%|████████  | 16/20 [5:43:16<1:25:47, 1286.90s/it]

[2921, 28, 54]
[2859, 15, 14]
[0.9787743923313934, 0.5357142857142857, 0.25925925925925924]
accuracy:  0.9617049617049617
macro_accuracy:  0.5912493124349795
f1_score:  [0.98213672 0.6        0.20895522]
f1_score_micro:  0.9617049617049617
f1_score_macro:  0.5970306488854252
Entity_Property_Epoch:  17
Average train loss: 0.01439175738069661
[3003, 66847]
[1958, 66127]
[0.652014652014652, 0.9892291351893129]
accuracy:  0.9747315676449535
macro_accuracy:  0.5470812624013216
f1_score:  [0.68931526 0.98683023]
f1_score_micro:  0.9747315676449535
f1_score_macro:  0.8380727434813364
Entity_Property_Epoch:  17
Average train loss: 0.03664976290179766


Epoch:  85%|████████▌ | 17/20 [6:04:44<1:04:21, 1287.03s/it]

[2921, 28, 54]
[2865, 15, 14]
[0.9808284833960972, 0.5357142857142857, 0.25925925925925924]
accuracy:  0.9637029637029637
macro_accuracy:  0.5919340094565474
f1_score:  [0.98318463 0.6        0.21875   ]
f1_score_micro:  0.9637029637029637
f1_score_macro:  0.6006448753145733
Entity_Property_Epoch:  18
Average train loss: 0.013640146190348354
[3003, 66847]
[1979, 66081]
[0.659007659007659, 0.9885409966041857]
accuracy:  0.9743736578382247
macro_accuracy:  0.5491828852039483
f1_score:  [0.68858733 0.986637  ]
f1_score_micro:  0.9743736578382246
f1_score_macro:  0.8376121695125848
Entity_Property_Epoch:  18
Average train loss: 0.03900131058209808


Epoch:  90%|█████████ | 18/20 [6:26:10<42:53, 1286.83s/it]  

[2921, 28, 54]
[2842, 15, 18]
[0.9729544676480657, 0.5357142857142857, 0.3333333333333333]
accuracy:  0.9573759573759574
macro_accuracy:  0.6140006955652282
f1_score:  [0.98033805 0.58823529 0.22929936]
f1_score_micro:  0.9573759573759574
f1_score_macro:  0.5992909015925312
Entity_Property_Epoch:  19
Average train loss: 0.012146625705917173
[3003, 66847]
[1991, 66069]
[0.663003663003663, 0.9883614821906742]
accuracy:  0.9743736578382247
macro_accuracy:  0.5504550483981124
f1_score:  [0.68988219 0.98663461]
f1_score_micro:  0.9743736578382246
f1_score_macro:  0.8382583997615956
Entity_Property_Epoch:  19
Average train loss: 0.030484878067072715


Epoch:  95%|█████████▌| 19/20 [6:47:37<21:26, 1286.71s/it]

[2921, 28, 54]
[2856, 15, 17]
[0.9777473467990414, 0.5357142857142857, 0.3148148148148148]
accuracy:  0.9617049617049617
macro_accuracy:  0.609425482442714
f1_score:  [0.98245614 0.58823529 0.24113475]
f1_score_micro:  0.9617049617049617
f1_score_macro:  0.6039420620805247
Entity_Property_Epoch:  20
Average train loss: 0.011062437473787084
[3003, 66847]
[2021, 66015]
[0.672993672993673, 0.9875536673298727]
accuracy:  0.9740300644237652
macro_accuracy:  0.5535157801078485
f1_score:  [0.69023224 0.98644691]
f1_score_micro:  0.9740300644237652
f1_score_macro:  0.838339574389106
Entity_Property_Epoch:  20
Average train loss: 0.033512031805148584


Epoch: 100%|██████████| 20/20 [7:09:03<00:00, 1287.15s/it]

[2921, 28, 54]
[2863, 15, 16]
[0.9801437863745293, 0.5357142857142857, 0.2962962962962963]
accuracy:  0.9637029637029637
macro_accuracy:  0.6040514561283704
f1_score:  [0.98351082 0.6        0.23880597]
f1_score_micro:  0.9637029637029637
f1_score_macro:  0.6074389303909856
training is done





## 모델 평가

In [30]:
def predict_from_korean_form(tokenizer, ce_model, pc_model, data):

    ce_model.to(device)
    ce_model.eval()
    for sentence in data:
        form = sentence['sentence_form']
        sentence['annotation'] = []
        if type(form) != str:
            print("form type is arong: ", form)
            continue
        for pair in entity_property_pair:
            
            tokenized_data = tokenizer(form, pair, padding='max_length', max_length=256, truncation=True)

            input_ids = torch.tensor([tokenized_data['input_ids']]).to(device)
            attention_mask = torch.tensor([tokenized_data['attention_mask']]).to(device)
            category_seq = [ i for i in range(len(entity_property_pair))] 
            category_seq = torch.tensor([category_seq]).to(device)

            with torch.no_grad():
                _, ce_logits = ce_model(input_ids, attention_mask, category_label_seq_tensor=category_seq)

            ce_predictions = torch.argmax(ce_logits, dim = -1)

            ce_result = tf_id_to_name[ce_predictions[0]]

            if ce_result == 'True':
                with torch.no_grad():
                    _, pc_logits = pc_model(input_ids, attention_mask)

                pc_predictions = torch.argmax(pc_logits, dim=-1)
                pc_result = polarity_id_to_name[pc_predictions[0]]

                sentence['annotation'].append([pair, pc_result])


    return data

## F1 score 계산 - 추출 성능 및 전체 성능에 대한 F1 score 따로 계산

In [31]:
def evaluation_f1(true_data, pred_data):

    true_data_list = true_data
    pred_data_list = pred_data

    ce_eval = {
        'TP': 0,
        'FP': 0,
        'FN': 0,
        'TN': 0
    }

    pipeline_eval = {
        'TP': 0,
        'FP': 0,
        'FN': 0,
        'TN': 0
    }

    for i in range(len(true_data_list)):

        # TP, FN checking
        is_ce_found = False
        is_pipeline_found = False
        for y_ano  in true_data_list[i]['annotation']:
            y_category = y_ano[0]
            y_polarity = y_ano[2]

            for p_ano in pred_data_list[i]['annotation']:
                p_category = p_ano[0]
                p_polarity = p_ano[1]

                if y_category == p_category:
                    is_ce_found = True
                    if y_polarity == p_polarity:
                        is_pipeline_found = True

                    break

            if is_ce_found is True:
                ce_eval['TP'] += 1
            else:
                ce_eval['FN'] += 1

            if is_pipeline_found is True:
                pipeline_eval['TP'] += 1
            else:
                pipeline_eval['FN'] += 1

            is_ce_found = False
            is_pipeline_found = False

        # FP checking
        for p_ano in pred_data_list[i]['annotation']:
            p_category = p_ano[0]
            p_polarity = p_ano[1]

            for y_ano  in true_data_list[i]['annotation']:
                y_category = y_ano[0]
                y_polarity = y_ano[2]

                if y_category == p_category:
                    is_ce_found = True
                    if y_polarity == p_polarity:
                        is_pipeline_found = True

                    break

            if is_ce_found is False:
                ce_eval['FP'] += 1

            if is_pipeline_found is False:
                pipeline_eval['FP'] += 1

    ce_precision = ce_eval['TP']/(ce_eval['TP']+ce_eval['FP'])
    ce_recall = ce_eval['TP']/(ce_eval['TP']+ce_eval['FN'])

    ce_result = {
        'Precision': ce_precision,
        'Recall': ce_recall,
        'F1': 2*ce_recall*ce_precision/(ce_recall+ce_precision)
    }

    pipeline_precision = pipeline_eval['TP']/(pipeline_eval['TP']+pipeline_eval['FP'])
    pipeline_recall = pipeline_eval['TP']/(pipeline_eval['TP']+pipeline_eval['FN'])

    pipeline_result = {
        'Precision': pipeline_precision,
        'Recall': pipeline_recall,
        'F1': 2*pipeline_recall*pipeline_precision/(pipeline_recall+pipeline_precision)
    }

    return {
        'category extraction result': ce_result,
        'entire pipeline result': pipeline_result
    }

## 테스트 데이터에 대한 평가

In [37]:
def test_sentiment_analysis():

    tokenizer = ElectraTokenizer.from_pretrained(base_model)
    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    test_data = jsonlload(dev_data_path)

    entity_property_test_data, polarity_test_data = get_dataset(test_data, tokenizer, max_len)

    entity_property_test_dataloader = DataLoader(entity_property_test_data, shuffle=True,
                                batch_size=batch_size)

    polarity_test_dataloader = DataLoader(polarity_test_data, shuffle=True,
                                                  batch_size=batch_size)
    
    model =  CategoryClassification(len(tf_id_to_name), len(tokenizer), lstm_hidden * 2, len(entity_property_pair), lstm_num_layer, bilstm_flag)
    model.load_state_dict(torch.load(test_category_extraction_model_path, map_location=device))
    model.to(device)
    model.eval()
            
    polarity_model = RoBertaBaseClassifier(len(polarity_id_to_name), len(tokenizer))
    polarity_model.load_state_dict(torch.load(test_polarity_classification_model_path, map_location=device))
    polarity_model.to(device)
    polarity_model.eval()

    pred_data = predict_from_korean_form(tokenizer, model, polarity_model, copy.deepcopy(test_data))

    # jsondump(pred_data, './pred_data.json')
    # pred_data = jsonload('./pred_data.json')

    print('F1 result: ', evaluation_f1(test_data, pred_data))

In [38]:
test_sentiment_analysis()

Some weights of the model checkpoint at monologg/koelectra-base-v3-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at monologg/koelectra-base-v3-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.weight', 'discr

F1 result:  {'category extraction result': {'Precision': 0.7505446623093682, 'Recall': 0.6713218577460215, 'F1': 0.708726212926453}, 'entire pipeline result': {'Precision': 0.7357531760435572, 'Recall': 0.6583306268268918, 'F1': 0.6948920123414466}}
