In [1]:
import argparse
import logging
import os
import tarfile

import torch
from torch.utils.data import DataLoader
from dataset import KlueReProcessor
from transformers import AutoTokenizer
from utils import SKRelationExtractionDataset
from model import SkeletonAwareRoberta


from easydict import EasyDict

In [2]:
# File name of the predictions file created inside the output directory.
KLUE_RE_OUTPUT = "output.csv"

# Configure the root logger once: INFO and above, with a timestamped
# "time - level - logger name - message" layout.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Logger used by the functions defined in this notebook.
logger = logging.getLogger(__name__)

In [3]:
# Default run configuration.
# NOTE(review): a later cell rebinds `args` before inference() is called, so
# these values only take effect if that cell is skipped.
args = EasyDict({
    "batch_size": 64,
    "data_dir": "./data",
    "model_dir": "./model",
    "model_tarname": "klue-re.tar.gz",
    # Falls back to "/output" when SM_OUTPUT_DATA_DIR is not set.
    "output_dir": os.environ.get("SM_OUTPUT_DATA_DIR", "/output"),
    "max_seq_length": 512,
    "relation_filename": "relation_list.json",
    "train_filename": "klue-re-v1.1_train.json",
    "valid_filename": "klue-re-v1.1_dev.json",
    # inference() reads args.test_filename; without this key the config raises
    # AttributeError there. Defaults to the dev split — TODO confirm the
    # intended evaluation file.
    "test_filename": "klue-re-v1.1_dev.json",
    "num_workers": 4,
})

In [4]:
def load_model(model_dir, model_tar_path):
    """Extract a gzipped model archive and load the model it contains.

    Args:
        model_dir: str: directory the archive is extracted into; the model is
            then loaded from this same directory.
        model_tar_path: str: full path to the ``.tar.gz`` archive (not just
            its file name).

    Returns:
        Whatever ``SkeletonAwareRoberta.from_pretrained(model_dir)`` returns.
    """
    # The context manager closes the archive handle even if extraction fails;
    # the original left it open.
    # NOTE(review): extractall() trusts the member paths stored in the archive.
    # Only use this with archives from a trusted source.
    with tarfile.open(model_tar_path, "r:gz") as tar:
        tar.extractall(path=model_dir)
    return SkeletonAwareRoberta.from_pretrained(model_dir)

In [5]:
@torch.no_grad()
def inference(args) -> None:
    """Run the model over the test file and write one prediction per example.

    Reads ``args.data_dir``, ``args.model_dir``, ``args.model_tarname``,
    ``args.output_dir``, ``args.test_filename`` and ``args.batch_size``.
    ``args`` is also passed to ``KlueReProcessor`` unchanged.

    Each output line in ``<args.output_dir>/<KLUE_RE_OUTPUT>`` holds the
    argmax class index, a tab, then the space-separated softmax probabilities.

    Args:
        args: attribute-style config object (e.g. an ``EasyDict``).

    Raises:
        AssertionError: if ``args.data_dir`` does not exist.
    """
    data_dir = args.data_dir
    model_dir = args.model_dir
    model_tar_path = os.path.join(model_dir, args.model_tarname)
    output_dir = args.output_dir

    assert os.path.exists(
        data_dir
    ), "Run inference code w/o data folder. Plz check out the path of data"

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, KLUE_RE_OUTPUT)

    # configure gpu
    num_gpus = torch.cuda.device_count()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    logger.info("Loading model via model.SkeletonAwareRoberta")
    # load model
    model = load_model(model_dir, model_tar_path).to(device)
    if num_gpus > 1:
        model = torch.nn.DataParallel(model)
    model.eval()

    # load tokenizer
    logger.info("Loading tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)

    logger.info("Data Loader : preprocessing data")
    data_path = os.path.join(data_dir, args.test_filename)
    processor = KlueReProcessor(args, tokenizer)
    features = processor._convert_features(processor._create_examples(data_path))
    dataset = SKRelationExtractionDataset(features)
    # NOTE(review): args.num_workers is defined in the config but is not
    # forwarded here, so loading runs in the main process — confirm intent.
    data_loader = DataLoader(dataset, batch_size=args.batch_size, drop_last=False)

    logger.info("Start inferencing")
    # The file is opened only once everything else is ready, and via a context
    # manager, so a failure mid-run cannot leak the handle and a failed model
    # load no longer truncates a previous output file.
    with open(output_path, "w") as output_file:
        for batch in data_loader:
            # Labels are not model inputs at inference time; drop them.
            input_data = {
                key: value.to(device)
                for key, value in batch.items()
                if key != "labels"
            }

            output = model(**input_data)
            logits = output[0]

            preds = torch.argmax(logits, dim=1).detach().cpu().numpy()
            probs = torch.softmax(logits, dim=1).detach().cpu().numpy()

            for pred, prob in zip(preds, probs):
                output_file.write(f"{pred}\t{' '.join(map(str, prob.tolist()))}\n")

    logger.info("Done inferencing")

In [6]:
# Configuration for the sample inference run below; this rebinds `args`.
args = EasyDict(
    batch_size=64,
    data_dir="./data",
    model_dir="./model",
    model_tarname="klue-re.tar.gz",
    output_dir="./output",
    max_seq_length=512,
    relation_filename="relation_list.json",
    test_filename="klue-re-v1.1_dev_sample_10.json",
    num_workers=4,
)

In [7]:
# Run inference with the current `args`; predictions are written to
# <args.output_dir>/output.csv (see KLUE_RE_OUTPUT).
inference(args)

12/05/2021 18:19:16 - INFO - __main__ - Loading model via model.SkeletonAwareRoberta
12/05/2021 18:19:47 - INFO - __main__ - Loading tokenizer
12/05/2021 18:19:47 - INFO - __main__ - Data Loader : preprocessing data
12/05/2021 18:19:47 - INFO - dataset - Using BertTokenizer for fixing tokenization result
12/05/2021 18:19:47 - INFO - dataset - *** Example ***
12/05/2021 18:19:47 - INFO - dataset - guid: klue-re-v1_dev_00000
12/05/2021 18:19:47 - INFO - dataset - origin example: InputExample(guid='klue-re-v1_dev_00000', text_a="<si> 인물  수량 </si>[SEP]20대 남성 <subj>A</subj>(26)씨가 아버지 치료비를 위해 B(<obj>30</obj>)씨가 모아둔 돈을 훔쳐 인터넷 방송 BJ에게 '별풍선'으로 쏜 사실이 알려졌다.", text_b=None, label='no_relation')
12/05/2021 18:19:47 - INFO - dataset - origin tokens: ['<si>', '인물', '수량', '</si>', '[SEP]', '20', '##대', '남성', '<subj>', 'A', '</subj>', '(', '26', ')', '씨', '##가', '아버지', '치료비', '##를', '위해', 'B', '(', '<obj>', '30', '</obj>', ')', '씨', '##가', '모아', '##둔', '돈', '##을', '훔쳐', '인터넷', '방송', 'B', '##J', '##에', '