In [1]:
from model import SkeletonAwareBERT, SkeletonAwareRoberta
from transformers import AutoTokenizer, AutoConfig,Trainer, TrainingArguments
from easydict import EasyDict
from dataset import KlueReProcessor
from utils import SKRelationExtractionDataset, seed_everything
import json
import os
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
import pandas as pd

In [2]:
seed_everything(42)

In [3]:
args = EasyDict({
    "batch_size": 64,
    "data_dir" : "./data",
    "model_dir": "./model",
    "model_tarname":"klue-re.tar.gz",
    "output_dir":os.environ.get("SM_OUTPUT_DATA_DIR", "/output"),
    "max_seq_length":512,
    "relation_filename" : "relation_list.json",
    "train_filename" : "klue-re-v1.1_train.json",
    "valid_filename" : "klue-re-v1.1_dev.json",
    "num_workers" : 4
})

In [4]:
# 환경 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

relation_class_file_path = os.path.join(args.data_dir, args.relation_filename)
valid_file_path = os.path.join(args.data_dir, args.valid_filename)
with open(relation_class_file_path, "r", encoding="utf-8") as f:
    relation_class = json.load(f)["relations"]

model_name_or_path = args.model_dir
config = AutoConfig.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(config.name_or_path,use_fast=False) 
krp = KlueReProcessor(args,tokenizer)

In [5]:
# 데이터 준비
valid_features = krp._convert_features(krp._create_examples(valid_file_path))
valid_dataset = SKRelationExtractionDataset(valid_features)
valid_loader = DataLoader(valid_dataset,args.batch_size, drop_last = False)

In [6]:
# 모델 준비
model = SkeletonAwareRoberta.from_pretrained(args.model_dir)
model.to(device)
model.eval()

SkeletonAwareRoberta(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(32008, 1024, padding_idx=1)
      (position_embeddings): Embedding(514, 1024, padding_idx=1)
      (token_type_embeddings): Embedding(1, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerN

In [7]:
all_preds = []
all_probs = []
with torch.no_grad():
    for batch in tqdm(valid_loader):
        input_data = {key : value.to(device) for key, value in batch.items() if not key == 'labels'}

        output = model(**input_data)
        logits = output[0]

        preds, probs = (
            torch.argmax(logits, dim=1).detach().cpu().numpy().tolist(),
            torch.softmax(logits, dim=1).detach().cpu().numpy().tolist(),
        )
        all_preds.extend(preds)
        all_probs.extend(probs)

100%|██████████| 122/122 [01:42<00:00,  1.19it/s]


In [8]:
with open(valid_file_path, "r", encoding="utf-8") as f:
    valid_json = json.load(f)

In [9]:
valid_labels = [data['label'] for data in valid_json]
pred_labels = [relation_class[pred] for pred in all_preds]

In [10]:
valid_sen = [data['sentence'] for data in valid_json]

In [11]:
df = pd.DataFrame(data={'sen': valid_sen, 'pred':pred_labels, 'label':valid_labels})

In [12]:
wrong_ans = df[df['pred'] != df['label']].reset_index()

In [13]:
for relation in relation_class:
    rel_total = len(df[df['label'] == relation])
    wrg_total = len(wrong_ans[wrong_ans['label'] == relation])
    print(relation,':', rel_total, '->',wrg_total, '|', (wrg_total/rel_total) * 100)

no_relation : 4631 -> 671 | 14.489311163895488
org:dissolved : 11 -> 4 | 36.36363636363637
org:founded : 20 -> 4 | 20.0
org:place_of_headquarters : 194 -> 40 | 20.618556701030926
org:alternate_names : 78 -> 22 | 28.205128205128204
org:member_of : 104 -> 55 | 52.88461538461539
org:members : 122 -> 47 | 38.52459016393443
org:political/religious_affiliation : 13 -> 4 | 30.76923076923077
org:product : 235 -> 71 | 30.21276595744681
org:founded_by : 11 -> 3 | 27.27272727272727
org:top_members/employees : 513 -> 201 | 39.1812865497076
org:number_of_employees/members : 17 -> 6 | 35.294117647058826
per:date_of_birth : 12 -> 2 | 16.666666666666664
per:date_of_death : 13 -> 1 | 7.6923076923076925
per:place_of_birth : 11 -> 2 | 18.181818181818183
per:place_of_death : 10 -> 1 | 10.0
per:place_of_residence : 124 -> 78 | 62.903225806451616
per:origin : 118 -> 24 | 20.33898305084746
per:employee_of : 242 -> 50 | 20.66115702479339
per:schools_attended : 11 -> 4 | 36.36363636363637
per:alternate_names :

In [27]:
wrong_ans[(wrong_ans['label'] == "per:date_of_birth")]['pred'].value_counts()

no_relation    2
Name: pred, dtype: int64