In [1]:
import os
import tarfile

import torch
from dataloader import KlueMrcDataLoaderGetter
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer, BertTokenizer
from transformers.data.metrics.squad_metrics import compute_predictions_logits
from transformers.data.processors.squad import SquadResult
from custom_model.model import RobertaQASplit
from easydict import EasyDict

# File name for the prediction output; the required name is fixed at "output.csv".
KLUE_MRC_OUTPUT: str = "output.csv"


In [2]:
def load_model_and_type(model_dir: str, model_tar_file: str):
    """Extract a model archive (pre-fetched from S3) and load the model and its type.

    Args:
        model_dir: directory that contains the tar file; the archive is extracted
            into this same directory and the model/config are loaded from it.
        model_tar_file: file name of the gzip-compressed tar archive inside
            ``model_dir``.

    Returns:
        A ``(model, model_type)`` tuple: the ``RobertaQASplit`` model loaded from
        ``model_dir`` and ``config.model_type`` read from the config in the same
        directory.
    """
    tarpath = os.path.join(model_dir, model_tar_file)
    # Context manager guarantees the archive handle is closed even if extraction fails.
    with tarfile.open(tarpath, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Extraction filters (Python 3.12+ and security backports) reject absolute
            # paths, ".." members and other entries that would escape `model_dir`.
            tar.extractall(path=model_dir, filter="data")
        else:
            tar.extractall(path=model_dir)

    model = RobertaQASplit.from_pretrained(model_dir)
    config = AutoConfig.from_pretrained(model_dir)
    return model, config.model_type


@torch.no_grad()
def inference(data_dir: str, model_dir: str, output_dir: str, args):
    """Run extractive-QA inference on a KLUE-MRC file and write the predictions.

    Args:
        data_dir: directory containing ``args.test_filename``.
        model_dir: directory holding ``args.model_tar_file``; the archive is
            extracted there and the model, config and tokenizer are loaded from it.
        output_dir: directory where the prediction file ``KLUE_MRC_OUTPUT`` is written.
        args: attribute-style config (e.g. ``EasyDict``) providing ``model_tar_file``,
            ``test_filename``, ``batch_size``, ``max_seq_length``, ``doc_stride``,
            ``max_query_length``, ``n_best_size``, ``max_answer_length`` and,
            optionally, ``num_workers``.

    Returns:
        The predictions returned by ``compute_predictions_logits``, which is also
        given ``os.path.join(output_dir, KLUE_MRC_OUTPUT)`` as its prediction file.
    """
    # configure gpu
    num_gpus = torch.cuda.device_count()
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # load model
    model, model_type = load_model_and_type(model_dir, args.model_tar_file)
    model.to(device)
    if num_gpus > 1:
        model = torch.nn.DataParallel(model)
    model.eval()

    # Honour the configured worker count; fall back to the previous default
    # (one worker per GPU) when `args` does not define `num_workers`.
    loader_kwargs = (
        {"num_workers": getattr(args, "num_workers", num_gpus), "pin_memory": True}
        if use_cuda
        else {}
    )

    # load tokenizer
    tokenizer = BertTokenizer.from_pretrained(model_dir)
    output_file_path = os.path.join(output_dir, KLUE_MRC_OUTPUT)

    klue_mrc_dataloader_getter = KlueMrcDataLoaderGetter(
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length,
    )
    klue_mrc_dataloader = klue_mrc_dataloader_getter.get_dataloader(
        file_path=os.path.join(data_dir, args.test_filename),
        batch_size=args.batch_size,
        **loader_kwargs
    )
    features = klue_mrc_dataloader.dataset.features

    # infer
    qa_result = []
    for input_ids, attention_mask, token_type_ids, idx in tqdm(klue_mrc_dataloader):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        # roberta does not accept token_type_id > 1, so segment ids are dropped for it
        token_type_ids = None if model_type == "roberta" else token_type_ids.to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # One device->host copy per batch instead of one per example.
        batch_start_logits = outputs.start_logits.tolist()
        batch_end_logits = outputs.end_logits.tolist()
        for feature_index, start_logits, end_logits in zip(
            idx, batch_start_logits, batch_end_logits
        ):
            unique_id = features[int(feature_index)].unique_id
            qa_result.append(SquadResult(unique_id, start_logits, end_logits))

    examples = klue_mrc_dataloader.dataset.examples
    do_lower_case = getattr(tokenizer, "do_lower_case", False)

    preds = compute_predictions_logits(
        all_examples=examples,
        all_features=features,
        all_results=qa_result,
        n_best_size=args.n_best_size,
        max_answer_length=args.max_answer_length,
        do_lower_case=do_lower_case,
        output_prediction_file=output_file_path,
        output_nbest_file=None,
        output_null_log_odds_file=None,
        verbose_logging=False,
        version_2_with_negative=True,
        null_score_diff_threshold=0,
        tokenizer=tokenizer,
    )
    return preds

In [5]:
if __name__ == "__main__":
    # Run configuration; EasyDict gives attribute access (args.batch_size, ...).
    args = EasyDict(
        batch_size=30,
        data_dir="./data",
        model_dir="./model",
        model_tar_file="klue_mrc_model.tar.gz",
        output_dir="./model",
        max_seq_length=384,
        max_query_length=64,
        test_filename="klue-mrc-v1.1_dev_sample_10.json",
        doc_stride=128,
        n_best_size=20,
        max_answer_length=30,
        num_workers=4,
    )

    data_dir, model_dir, output_dir = args.data_dir, args.model_dir, args.output_dir
    inference(data_dir, model_dir, output_dir, args)

convert squad examples to features: 100%|██████████| 13/13 [00:00<00:00, 52.74it/s]
add example index and unique id: 100%|██████████| 13/13 [00:00<00:00, 119574.46it/s]
100%|██████████| 2/2 [00:00<00:00,  2.03it/s]


In [1]:
a = dict(a=1)  # scratch cell: a single-entry dict

In [2]:
# Scratch check: number of keys in `a`; as the cell's last expression, the value is displayed.
len(a)

1