In [None]:
# Install Hugging Face transformers into the environment of the running kernel.
# `%pip` targets the kernel's interpreter, whereas `!pip` runs whichever `pip`
# is first on the shell PATH. The recorded run below resolved transformers
# 4.9.2; append `==4.9.2` to reproduce that exact environment.
%pip install transformers

Collecting transformers
  Downloading transformers-4.9.2-py3-none-any.whl (2.6 MB)
[K     |████████████████████████████████| 2.6 MB 5.0 MB/s 
Collecting huggingface-hub==0.0.12
  Downloading huggingface_hub-0.0.12-py3-none-any.whl (37 kB)
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 43.7 MB/s 
[?25hCollecting sacremoses
  Downloading sacremoses-0.0.45-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 43.2 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)
[K     |████████████████████████████████| 636 kB 70.4 MB/s 
Installing collected packages: tokenizers, sacremoses, pyyaml, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 3.13
    Uninstalling PyYAML-3.13:
      Successfully uninsta

In [None]:
import json
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.utils import get_file
from transformers import BertTokenizer

# Training set: get_file downloads the JSON and returns the local path of
# the cached copy.
train_data_url = "https://korquad.github.io/dataset/KorQuAD_v1.0_train.json"
train_path = get_file("train.json", train_data_url)

# Evaluation (dev) set.
eval_data_url = "https://korquad.github.io/dataset/KorQuAD_v1.0_dev.json"
eval_path = get_file("eval.json", eval_data_url)

# Read with an explicit UTF-8 encoding (the JSON interchange encoding) so the
# Korean text does not depend on the platform default, and use `with` so the
# file handles are closed deterministically.
with open(train_path, encoding="utf-8") as f:
    train_data = json.load(f)
with open(eval_path, encoding="utf-8") as f:
    dev_data = json.load(f)

print(train_path)
print(eval_path)

Downloading data from https://korquad.github.io/dataset/KorQuAD_v1.0_train.json
Downloading data from https://korquad.github.io/dataset/KorQuAD_v1.0_dev.json
/root/.keras/datasets/train.json
/root/.keras/datasets/eval.json


In [None]:
# Sequence-length budget; passed to the tokenizer as max_length in parsing().
MAX_SEQ_LEN = 128

# Intended caps on the number of examples: processing everything takes a
# long time, so the amount of data is limited.
MAX_TRAIN_LEN = 50000
MAX_TEST_LEN = 1000

# Multilingual cased BERT tokenizer; 'bert_ckpt' is given as its cache_dir.
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-multilingual-cased',
    cache_dir='bert_ckpt',
)

In [None]:
def parsing(p_data, max_len = 50000):
    """Convert a KorQuAD-style dict into BERT inputs and answer-span targets.

    Walks p_data["data"] -> "paragraphs" -> "qas", takes the first answer of
    each question, and keeps the example only when the answer end offset is
    smaller than MAX_SEQ_LEN - len(question). At most max_len kept examples
    are used. The (question, context) pairs are then encoded with the
    module-level `tokenizer`.

    Args:
        p_data: parsed JSON dict with the nesting described above.
        max_len: maximum number of examples to return (non-negative int);
            None disables the cap.

    Returns:
        Tuple (x_ids, x_msk, x_typ, y_start, y_end) of numpy arrays: input
        ids, attention mask, token type ids, and the answer start / end
        offsets taken from the data set.
    """
    examples = []  # (question, context, answer_start, answer_end)
    for item in p_data["data"]:
        for para in item["paragraphs"]:
            for qa in para["qas"]:
                # Only the first listed answer is used.
                first_answer = qa["answers"][0]
                i_start = first_answer["answer_start"]
                i_end = i_start + len(first_answer["text"])
                quest = qa["question"]

                # NOTE(review): i_end and len(quest) are character counts,
                # while MAX_SEQ_LEN is used as a token budget below -- this
                # is only a rough "answer fits in the window" heuristic, and
                # y_start / y_end are presumably character offsets into the
                # context rather than token positions; confirm with the
                # model that consumes them.
                if i_end < MAX_SEQ_LEN - len(quest):
                    examples.append((quest, para["context"], i_start, i_end))

    # Apply the cap requested by the caller. Previously max_len was accepted
    # but never used, so the data set was never limited. Slicing with None
    # keeps everything.
    examples = examples[:max_len]

    # Build the BERT input from (question, paragraph) pairs.
    qa_pairs = [(quest, ctx) for quest, ctx, _, _ in examples]
    qa_enc = tokenizer.batch_encode_plus(
                qa_pairs,
                add_special_tokens = True,
                padding = True,
                truncation = True,
                max_length = MAX_SEQ_LEN,
                return_attention_mask = True,
                return_token_type_ids=True,
                return_tensors = 'tf')

    x_ids = qa_enc['input_ids'].numpy()
    x_msk = qa_enc['attention_mask'].numpy()
    x_typ = qa_enc['token_type_ids'].numpy()

    # Final output targets of the KorQuAD model.
    y_start = np.array([start for _, _, start, _ in examples])
    y_end = np.array([end for _, _, _, end in examples])

    return x_ids, x_msk, x_typ, y_start, y_end

In [None]:
# Encode the training split.
(
    x_train_ids,
    x_train_msk,
    x_train_typ,
    y_train_start,
    y_train_end,
) = parsing(train_data, max_len=MAX_TRAIN_LEN)

# Encode the dev split, which serves as the test set here.
(
    x_test_ids,
    x_test_msk,
    x_test_typ,
    y_test_start,
    y_test_end,
) = parsing(dev_data, max_len=MAX_TEST_LEN)

In [None]:
# Make sure the output directory exists before writing: none of the visible
# cells creates /data, and open(..., 'wb') raises FileNotFoundError when the
# parent directory is missing.
os.makedirs('/data', exist_ok=True)

# Save the vocabulary.
with open('/data/vocabulary.pickle', 'wb') as f:
    pickle.dump(tokenizer.get_vocab(), f, pickle.DEFAULT_PROTOCOL)

# Save the training data.
with open('/data/train_encoded.pickle', 'wb') as f:
    pickle.dump([x_train_ids, x_train_msk, x_train_typ, y_train_start, y_train_end], f, pickle.DEFAULT_PROTOCOL)

# Save the test data.
with open('/data/test_encoded.pickle', 'wb') as f:
    pickle.dump([x_test_ids, x_test_msk, x_test_typ, y_test_start, y_test_end], f, pickle.DEFAULT_PROTOCOL)
