In [49]:
import pandas as pd
import os
import json
import pickle
from tqdm.auto import tqdm
import numpy as np
from typing import List
import re
from fast_bm25 import BM25
from sklearn.model_selection import train_test_split
from utils.utils import *
import random
from collections import Counter, defaultdict

# Fix the random seed before any sampling happens in this notebook.
# NOTE(review): seed_everything is not imported by name; presumably it comes from
# the `from utils.utils import *` star import -- confirm which libraries it seeds.
seed_everything(42)

# KorQuAD

In [2]:
def make_instance(data:List):
    """Flatten SQuAD-style articles into one record per question.

    Args:
        data: list of article dicts. Each needs a ``title`` or ``doc_title``
            key and a ``paragraphs`` list whose items carry ``context`` and
            ``qas``.

    Returns:
        A list of dicts with keys ``title``, ``context``, ``question``,
        ``answer`` and ``label``. ``answer`` is a list of answer texts when
        ``qas['answers']`` is a list, otherwise the single ``text`` value.
        ``label`` is 0 only when ``is_impossible`` is present and equal to
        ``True``; every other record gets 1.

    Raises:
        Exception: with message ``'title error'`` when an article has
            neither a ``title`` nor a ``doc_title``.
    """
    records = []
    for doc in tqdm(data):
        raw_title = doc.get('title')
        if raw_title is None:
            raw_title = doc.get('doc_title')
        if raw_title is None:
            raise Exception('title error')
        # Titles use underscores where spaces are meant.
        title = re.sub('_', ' ', raw_title)
        for para in doc['paragraphs']:
            context = para['context']
            for qa in para['qas']:
                question = qa['question']
                answers = qa['answers']
                if type(answers) == list:
                    answer = [item['text'] for item in answers]
                else:
                    answer = answers['text']
                # A missing flag compares unequal to True, so it also yields label 1.
                label = 0 if qa.get('is_impossible') == True else 1
                records.append(dict(title=title, context=context, question=question, answer=answer, label=label))
    return records

In [3]:
def no_answer_make_instance(data:List):
    """Flatten articles from an unanswerable-question set; every record gets label 0.

    Args:
        data: list of article dicts with ``title`` or ``doc_title`` and a
            ``paragraphs`` list whose items carry ``context`` and ``qas``.

    Returns:
        A list of dicts with keys ``title``, ``context``, ``question``,
        ``answer`` and ``label``. ``answer`` is the raw ``answers`` value of
        the question (``None`` when absent) and ``label`` is always 0.

    Raises:
        Exception: with message ``'title error'`` when an article has
            neither a ``title`` nor a ``doc_title``.
    """
    records = []
    for doc in tqdm(data):
        raw_title = doc.get('title')
        if raw_title is None:
            raw_title = doc.get('doc_title')
        if raw_title is None:
            raise Exception('title error')
        title = re.sub('_', ' ', raw_title)
        for para in doc['paragraphs']:
            context = para['context']
            records.extend(
                dict(title=title, context=context, question=qa['question'], answer=qa.get('answers'), label=0)
                for qa in para['qas']
            )
    return records

In [5]:
def annotation(data_list, max_tries=1000):
    """Build reranker examples: each positive is paired with one sampled negative.

    Run this separately for the train, validation and test splits so that
    negatives are only drawn from passages of the same split.

    Every passage is rendered as ``"<title> <context>"`` with whitespace in the
    context collapsed to single spaces. The same layout is used for the
    sampling pool, the positives and the unanswerable records.

    Args:
        data_list: list of record lists. Each record is a dict with keys
            ``title``, ``context``, ``question``, ``answer`` and ``label``.
        max_tries: maximum number of random draws per positive when looking
            for a negative. If no draw qualifies, the positive is kept without
            a negative instead of looping forever.

    Returns:
        A list of dicts with keys ``question``, ``answer``, ``context`` and
        ``label``: 1 for the gold passage, 0 for a sampled passage that is not
        the gold passage and contains none of the answer strings, 2 for
        records whose input label was not 1.
    """
    output = []
    total_contexts = set()
    label_1 = 0
    label_0 = 0
    label_0_0 = 0
    skipped = 0
    for i in tqdm(data_list,desc='make_context'):
        for j in i:
            # Pool entries must look exactly like positives ("title context"),
            # otherwise negatives are recognisable by the missing separator.
            total_contexts.add(j['title']+' '+' '.join(j['context'].split()))
    # Set iteration order depends on string hashing, which varies between
    # interpreter runs; sorting makes seeded sampling reproducible.
    total_contexts = sorted(total_contexts)
    for i in data_list:
        for j in tqdm(i, desc='attach'):
            q = j['question']
            a = j['answer']
            passage = j['title']+' '+' '.join(j['context'].split())
            if j['label']==1:
                output.append(dict(question=q, answer=a, context=passage, label=1))
                label_1+=1
                # The answer is either a list of strings or a single string.
                answers = a if isinstance(a, list) else [a]
                for _ in range(max_tries):
                    tmp = random.choice(total_contexts)
                    # Reject the gold passage itself and any passage that
                    # contains one of the answer strings.
                    if tmp != passage and not any(answer in tmp for answer in answers):
                        label_0+=1
                        output.append(dict(question=q, answer=a, context=tmp, label=0))
                        break
                else:
                    skipped+=1
            else:
                label_0_0 +=1
                output.append(dict(question=q, answer=a, context=passage, label=2))
    print(f'positive 개수 : {label_1}, negative 개수 : {label_0}, is_impossible neagtive 개수 : {label_0_0}, total negative 개수 : {label_0+label_0_0}')
    if skipped:
        print(f'no negative found within {max_tries} draws for {skipped} positives')

    return output

In [4]:
# Read both KorQuAD 1.0 files as UTF-8, like every other dataset in this
# notebook, instead of relying on the platform default encoding, and close each
# file handle as soon as it has been parsed.
with open('../../data/MRC_data/KorQuAD 1.0/KorQuAD_v1.0_train.json','r',encoding='utf-8') as f:
    korquad_train = json.load(f)
with open('../../data/MRC_data/KorQuAD 1.0/KorQuAD_v1.0_dev.json','r',encoding='utf-8') as f:
    korquad_dev = json.load(f)

In [5]:
# Flatten the parsed JSON into one record per question.
# NOTE: each name is rebound from the parsed JSON dict to a list of records, so
# re-running this cell without reloading the files fails on ['data'].
korquad_train = make_instance(korquad_train['data'])
korquad_dev = make_instance(korquad_dev['data'])

100%|████████████████████████████████████████████████████████████████████████████| 1420/1420 [00:00<00:00, 9343.06it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 140/140 [00:00<00:00, 21537.65it/s]


In [6]:
# Only a train and a dev file were loaded, so hold out 10% of the train records
# as a third split. The split is over question records, not over articles.
korquad_train, korquad_test = train_test_split(korquad_train, test_size = 0.1, random_state=42, shuffle=True)

In [7]:
# Export answerable KorQuAD records in DPR format (question, answer, one positive passage).
# Split naming: the 10% hold-out taken from train is written as 'dev' and the
# official dev file is written as 'test'.
path = '../../data/dpr/korquad'
# Create the target directory first, as the ad / news / klue export cells do.
os.makedirs(path, exist_ok=True)
for name,data in zip(['train','dev','test'],[korquad_train, korquad_test, korquad_dev]):
    dpr = []
    for i in data:
        if i['label']==1:
            dpr.append(dict(question=i['question'], answer = i['answer'], positive_ctxs=[dict(title=i['title'], context=' '.join(i['context'].split()))]))
    save_jsonl(path,dpr,name)

100%|█████████████████████████████████████████████████████████████████████████| 54366/54366 [00:00<00:00, 84061.33it/s]
100%|███████████████████████████████████████████████████████████████████████████| 6041/6041 [00:00<00:00, 82450.53it/s]
100%|███████████████████████████████████████████████████████████████████████████| 5774/5774 [00:00<00:00, 90034.10it/s]


## rerank

In [8]:
# Reranker files keep every record with its label. Split naming matches the DPR
# export: the hold-out from train is saved as 'dev', the official dev as 'test'.
for split_name, split_rows in (('train', korquad_train), ('dev', korquad_test), ('test', korquad_dev)):
    save_jsonl('../../data/rerank/v2/korquad', split_rows, split_name)

100%|█████████████████████████████████████████████████████████████████████████| 54366/54366 [00:00<00:00, 84193.24it/s]
100%|███████████████████████████████████████████████████████████████████████████| 6041/6041 [00:00<00:00, 89252.77it/s]
100%|███████████████████████████████████████████████████████████████████████████| 5774/5774 [00:00<00:00, 85604.39it/s]


# 기계독해 - dev 존재 x

In [9]:
# Load the three machine-reading-comprehension files, closing each handle once parsed.
with open('../../data/MRC_data/기계독해/기계독해분야/01.Normal/ko_nia_normal_squad_all.json','r',encoding='utf-8') as f:
    mrc_train_1 = json.load(f)
# no answer
with open('../../data/MRC_data/기계독해/기계독해분야/02.NoAnswer/ko_nia_noanswer_squad_all.json','r',encoding='utf-8') as f:
    mrc_train_2 = json.load(f)
with open('../../data/MRC_data/기계독해/기계독해분야/03.Clue_/ko_nia_clue0529_squad_all.json','r',encoding='utf-8') as f:
    mrc_train_3 = json.load(f)

In [10]:
# Flatten each file into question records. The NoAnswer file (set 2) goes
# through no_answer_make_instance, so all of its records get label 0.
# NOTE: the names are rebound from parsed JSON to record lists; reload the
# files before re-running this cell.
mrc_train_1 = make_instance(mrc_train_1['data'])
mrc_train_2 = no_answer_make_instance(mrc_train_2['data'])
mrc_train_3 = make_instance(mrc_train_3['data'])

100%|█████████████████████████████████████████████████████████████████████████| 47314/47314 [00:01<00:00, 27173.84it/s]
100%|████████████████████████████████████████████████████████████████████████| 20030/20030 [00:00<00:00, 356824.83it/s]
100%|████████████████████████████████████████████████████████████████████████| 34500/34500 [00:00<00:00, 256658.85it/s]


In [11]:
# This dataset ships without a dev file, so each set is split twice:
# first 10% is held out as test, then 10% of the remaining 90% as dev
# (roughly 81% / 9% / 10% train / dev / test). Splits are over question records.
mrc_train_1, mrc_test_1 = train_test_split(mrc_train_1,test_size = 0.1,random_state=42, shuffle=True)
mrc_train_1, mrc_dev_1 = train_test_split(mrc_train_1,test_size = 0.1,random_state=42, shuffle=True)
mrc_train_2, mrc_test_2 = train_test_split(mrc_train_2,test_size = 0.1,random_state=42, shuffle=True)
mrc_train_2, mrc_dev_2 = train_test_split(mrc_train_2,test_size = 0.1,random_state=42, shuffle=True)
mrc_train_3, mrc_test_3 = train_test_split(mrc_train_3,test_size = 0.1,random_state=42, shuffle=True)
mrc_train_3, mrc_dev_3 = train_test_split(mrc_train_3,test_size = 0.1,random_state=42, shuffle=True)

## save for dpr

In [12]:
# Export answerable records (sets 1 and 3; the no-answer set 2 is left out) in DPR format.
path = '../../data/dpr/mrc'
# Create the target directory first, as the ad / news / klue export cells do.
os.makedirs(path, exist_ok=True)
for name,data in zip(['train','dev','test'],[mrc_train_1+mrc_train_3, mrc_dev_1+mrc_dev_3, mrc_test_1+mrc_test_3]):
    dpr = []
    for i in data:
        if i['label']==1:
            dpr.append(dict(question=i['question'], answer = i['answer'], positive_ctxs=[dict(title=i['title'], context=' '.join(i['context'].split()))]))
    save_jsonl(path,dpr,name)

100%|███████████████████████████████████████████████████████████████████████| 275469/275469 [00:04<00:00, 61489.06it/s]
100%|█████████████████████████████████████████████████████████████████████████| 30609/30609 [00:00<00:00, 56709.80it/s]
100%|█████████████████████████████████████████████████████████████████████████| 34010/34010 [00:00<00:00, 60533.04it/s]


## rerank

In [13]:
# Reranker data uses all three sets, including the label-0 no-answer set (2)
# that the DPR export above left out.
train_data = mrc_train_1+mrc_train_2+mrc_train_3
dev_data = mrc_dev_1+mrc_dev_2+mrc_dev_3
test_data = mrc_test_1+mrc_test_2+mrc_test_3

In [14]:
# Write the merged reranker splits for the machine-reading-comprehension dataset.
for split_name, split_rows in (('train', train_data), ('dev', dev_data), ('test', test_data)):
    save_jsonl('../../data/rerank/v2/mrc', split_rows, split_name)

100%|███████████████████████████████████████████████████████████████████████| 356666/356666 [00:05<00:00, 64607.82it/s]
100%|█████████████████████████████████████████████████████████████████████████| 39631/39631 [00:00<00:00, 63268.36it/s]
100%|█████████████████████████████████████████████████████████████████████████| 44035/44035 [00:00<00:00, 62719.79it/s]


# 도서자료 기계독해 - dev 존재

In [None]:
# Load the book-corpus train and validation files, closing each handle once parsed.
with open('../../data/MRC_data/도서자료 기계독해/Training/도서_train/도서_220419_add/도서_220419_add.json','r',encoding='utf-8') as f:
    book_mrc_train = json.load(f)
with open('../../data/MRC_data/도서자료 기계독해/Validation/도서_valid/도서.json','r',encoding='utf-8') as f:
    book_mrc_dev = json.load(f)

In [16]:
# Flatten both files into question records.
# NOTE: the names are rebound from parsed JSON to record lists; reload the
# files before re-running this cell.
book_mrc_train = make_instance(book_mrc_train['data'])
book_mrc_dev = make_instance(book_mrc_dev['data'])

100%|████████████████████████████████████████████████████████████████████████████| 5368/5368 [00:03<00:00, 1379.38it/s]
100%|███████████████████████████████████████████████████████████████████████████| 1994/1994 [00:00<00:00, 41668.04it/s]


In [17]:
# Hold out 5% of the train records as a third split (split is over question records).
book_mrc_train, book_mrc_test = train_test_split(book_mrc_train,test_size = 0.05,random_state=42, shuffle=True)

In [18]:
# Export answerable book records in DPR format.
# Split naming: the 5% hold-out taken from train is written as 'dev' and the
# validation file is written as 'test'.
path = '../../data/dpr/book'
# Create the target directory first, as the ad / news / klue export cells do.
os.makedirs(path, exist_ok=True)
for name,data in zip(['train','dev','test'],[book_mrc_train, book_mrc_test, book_mrc_dev]):
    dpr = []
    for i in data:
        if i['label']==1:
            dpr.append(dict(question=i['question'], answer = i['answer'], positive_ctxs=[dict(title=i['title'], context=' '.join(i['context'].split()))]))
    save_jsonl(path,dpr,name)

100%|███████████████████████████████████████████████████████████████████████| 598504/598504 [00:08<00:00, 73855.31it/s]
100%|█████████████████████████████████████████████████████████████████████████| 31496/31496 [00:00<00:00, 75922.42it/s]
100%|█████████████████████████████████████████████████████████████████████████| 35000/35000 [00:00<00:00, 85865.40it/s]


## rerank

In [19]:
def annotation(data):
    """Concatenate a list of record lists into one flat list.

    Once this cell runs, this definition replaces the sampling-based
    ``annotation`` defined earlier in the notebook.

    Args:
        data: a list whose elements are iterables of records. Any value that
            is not a list is passed through untouched.

    Returns:
        A new flat list when ``data`` is a list, otherwise ``data`` itself.
    """
    if not isinstance(data, list):
        return data
    return [record for group in data for record in group]

In [20]:
# annotation here is the flattening version defined in the cell above, so each
# call just unwraps the single list it is given.
# Split naming: the hold-out from train becomes dev_data, the validation file test_data.
train_data = annotation([book_mrc_train])
dev_data = annotation([book_mrc_test])
test_data = annotation([book_mrc_dev])

In [21]:
# Write the reranker splits for the book dataset.
for split_name, split_rows in (('train', train_data), ('dev', dev_data), ('test', test_data)):
    save_jsonl('../../data/rerank/v2/book', split_rows, split_name)

100%|███████████████████████████████████████████████████████████████████████| 855000/855000 [00:11<00:00, 76514.73it/s]
100%|█████████████████████████████████████████████████████████████████████████| 45000/45000 [00:00<00:00, 72788.52it/s]
100%|█████████████████████████████████████████████████████████████████████████| 50000/50000 [00:00<00:00, 82577.69it/s]


# 행정문서 기계독해

In [22]:
# Load the administrative-document files (five task types, train and validation),
# closing each handle once parsed.
with open('../../data/MRC_data/행정 문서 대상 기계독해/01.데이터/1.Training/라벨링데이터/TL_multiple_choice/TL_multiple_choice.json','r',encoding='utf-8') as f:
    ad_train_1 = json.load(f)
with open('../../data/MRC_data/행정 문서 대상 기계독해/01.데이터/1.Training/라벨링데이터/TL_span_extraction/TL_span_extraction.json','r',encoding='utf-8') as f:
    ad_train_2 = json.load(f)
with open('../../data/MRC_data/행정 문서 대상 기계독해/01.데이터/1.Training/라벨링데이터/TL_span_extraction_how/TL_span_extraction_how.json','r',encoding='utf-8') as f:
    ad_train_3 = json.load(f)
# unanswerable questions
with open('../../data/MRC_data/행정 문서 대상 기계독해/01.데이터/1.Training/라벨링데이터/TL_unanswerable/TL_unanswerable.json','r',encoding='utf-8') as f:
    ad_train_4 = json.load(f)
with open('../../data/MRC_data/행정 문서 대상 기계독해/01.데이터/1.Training/라벨링데이터/TL_text_entailment/TL_text_entailment.json','r',encoding='utf-8') as f:
    ad_train_5 = json.load(f)

with open('../../data/MRC_data/행정 문서 대상 기계독해/01.데이터/2.Validation/라벨링데이터/VL_multiple_choice/VL_multiple_choice.json','r',encoding='utf-8') as f:
    ad_dev_1 = json.load(f)
with open('../../data/MRC_data/행정 문서 대상 기계독해/01.데이터/2.Validation/라벨링데이터/VL_span_extraction/VL_span_extraction.json','r',encoding='utf-8') as f:
    ad_dev_2 = json.load(f)
with open('../../data/MRC_data/행정 문서 대상 기계독해/01.데이터/2.Validation/라벨링데이터/VL_span_extraction_how/VL_span_extraction_how.json','r',encoding='utf-8') as f:
    ad_dev_3 = json.load(f)
# unanswerable questions
with open('../../data/MRC_data/행정 문서 대상 기계독해/01.데이터/2.Validation/라벨링데이터/VL_unanswerable/VL_unanswerable.json','r',encoding='utf-8') as f:
    ad_dev_4 = json.load(f)
with open('../../data/MRC_data/행정 문서 대상 기계독해/01.데이터/2.Validation/라벨링데이터/VL_text_entailment/VL_text_entailment.json','r',encoding='utf-8') as f:
    ad_dev_5 = json.load(f)

In [23]:
# Flatten every file into question records. Set 4 is the unanswerable file and
# goes through no_answer_make_instance, so all of its records get label 0.
# NOTE: the names are rebound from parsed JSON to record lists; reload the
# files before re-running this cell.
ad_train_1 =make_instance(ad_train_1['data'])
ad_train_2 =make_instance(ad_train_2['data'])
ad_train_3 =make_instance(ad_train_3['data'])
ad_train_4 =no_answer_make_instance(ad_train_4['data'])
ad_train_5 =make_instance(ad_train_5['data'])

ad_dev_1 = make_instance(ad_dev_1['data'])
ad_dev_2 = make_instance(ad_dev_2['data'])
ad_dev_3 = make_instance(ad_dev_3['data'])
ad_dev_4 = no_answer_make_instance(ad_dev_4['data'])
ad_dev_5 = make_instance(ad_dev_5['data'])

100%|████████████████████████████████████████████████████████████████████████| 15085/15085 [00:00<00:00, 376878.26it/s]
100%|████████████████████████████████████████████████████████████████████████| 63932/63932 [00:00<00:00, 481721.45it/s]
100%|████████████████████████████████████████████████████████████████████████| 29074/29074 [00:00<00:00, 429260.55it/s]
100%|██████████████████████████████████████████████████████████████████████████| 9828/9828 [00:00<00:00, 456885.93it/s]
100%|████████████████████████████████████████████████████████████████████████| 16872/16872 [00:00<00:00, 456512.58it/s]
100%|██████████████████████████████████████████████████████████████████████████| 2813/2813 [00:00<00:00, 344886.79it/s]
100%|████████████████████████████████████████████████████████████████████████| 11985/11985 [00:00<00:00, 505869.25it/s]
100%|██████████████████████████████████████████████████████████████████████████| 5458/5458 [00:00<00:00, 491888.94it/s]
100%|███████████████████████████████████

In [24]:
# Hold out 10% of each training set as a third split (split is over question records).
ad_train_1,ad_test_1 = train_test_split(ad_train_1,test_size = 0.1,random_state=42, shuffle=True)
ad_train_2,ad_test_2 = train_test_split(ad_train_2,test_size = 0.1,random_state=42, shuffle=True)
ad_train_3, ad_test_3 = train_test_split(ad_train_3,test_size = 0.1,random_state=42, shuffle=True)
ad_train_4, ad_test_4 = train_test_split(ad_train_4,test_size = 0.1,random_state=42, shuffle=True)
ad_train_5, ad_test_5 = train_test_split(ad_train_5,test_size = 0.1,random_state=42, shuffle=True)

In [25]:
# Merge sets 1, 2, 3 and 5; the unanswerable set 4 is kept separate.
ad_train = ad_train_1+ad_train_2+ad_train_3+ad_train_5
ad_dev = ad_dev_1+ad_dev_2+ad_dev_3+ad_dev_5
ad_test = ad_test_1+ad_test_2+ad_test_3+ad_test_5

In [26]:
# Export answerable administrative-document records in DPR format.
# Split naming: the hold-out from train is written as 'dev', the validation files as 'test'.
path = '../../data/dpr/ad'
os.makedirs(path, exist_ok=True)
for name,data in zip(['train','dev','test'],[ad_train, ad_test, ad_dev]):
    dpr = [
        dict(question=i['question'], answer = i['answer'], positive_ctxs=[dict(title=i['title'], context=' '.join(i['context'].split()))])
        for i in data
        if i['label']==1
    ]
    save_jsonl(path,dpr,name)

100%|███████████████████████████████████████████████████████████████████████| 187391/187391 [00:02<00:00, 77568.54it/s]
100%|█████████████████████████████████████████████████████████████████████████| 20822/20822 [00:00<00:00, 78296.85it/s]
100%|█████████████████████████████████████████████████████████████████████████| 26026/26026 [00:00<00:00, 88499.58it/s]


In [27]:
# Reranker data = merged sets plus the label-0 unanswerable set 4.
# annotation is the flattening version defined in the book section, so this is a plain concatenation.
train_data = annotation([ad_train, ad_train_4])
dev_data = annotation([ad_test, ad_test_4])
test_data = annotation([ad_dev, ad_dev_4])

In [28]:
# Write the reranker splits for the administrative-document dataset.
for split_name, split_rows in (('train', train_data), ('dev', dev_data), ('test', test_data)):
    save_jsonl('../../data/rerank/v2/ad', split_rows, split_name)

100%|███████████████████████████████████████████████████████████████████████| 202151/202151 [00:02<00:00, 71200.03it/s]
100%|█████████████████████████████████████████████████████████████████████████| 22462/22462 [00:00<00:00, 69069.15it/s]
100%|█████████████████████████████████████████████████████████████████████████| 28076/28076 [00:00<00:00, 78974.08it/s]


# 뉴스 기사 기계독해

In [29]:
# Load the news-article files (four task types, train and validation),
# closing each handle once parsed. Set 4 is the unanswerable file.
with open('../../data/MRC_data/뉴스 기사 기계독해/01.데이터/1.Training/라벨링데이터/TL_text_entailment/TL_text_entailment.json','r',encoding='utf-8') as f:
    news_mrc_train_1 = json.load(f)
with open('../../data/MRC_data/뉴스 기사 기계독해/01.데이터/1.Training/라벨링데이터/TL_span_extraction/TL_span_extraction.json','r',encoding='utf-8') as f:
    news_mrc_train_2 = json.load(f)
with open('../../data/MRC_data/뉴스 기사 기계독해/01.데이터/1.Training/라벨링데이터_221115_add/TL_span_inference/TL_span_inference.json','r',encoding='utf-8') as f:
    news_mrc_train_3 = json.load(f)
with open('../../data//MRC_data/뉴스 기사 기계독해/01.데이터/1.Training/라벨링데이터/TL_unanswerable/TL_unanswerable.json','r',encoding='utf-8') as f:
    news_mrc_train_4 = json.load(f)
with open('../../data/MRC_data/뉴스 기사 기계독해/01.데이터/2.Validation/라벨링데이터/VL_span_extraction/VL_span_extraction.json','r',encoding='utf-8') as f:
    news_mrc_dev_1 = json.load(f)
with open('../../data/MRC_data/뉴스 기사 기계독해/01.데이터/2.Validation/라벨링데이터/VL_span_inference/VL_span_inference.json','r',encoding='utf-8') as f:
    news_mrc_dev_2 = json.load(f)
with open('../../data/MRC_data/뉴스 기사 기계독해/01.데이터/2.Validation/라벨링데이터/VL_text_entailment/VL_text_entailment.json','r',encoding='utf-8') as f:
    news_mrc_dev_3 = json.load(f)
with open('../../data//MRC_data/뉴스 기사 기계독해/01.데이터/2.Validation/라벨링데이터/VL_unanswerable/VL_unanswerable.json','r',encoding='utf-8') as f:
    news_mrc_dev_4 = json.load(f)

In [30]:
# Flatten every file into question records. Set 4 is the unanswerable file and
# goes through no_answer_make_instance, so all of its records get label 0.
# NOTE: the train and dev task-type numbering differs for sets 1-3 (see the
# file paths in the loading cell); only the merged lists are used afterwards.
# NOTE: the names are rebound from parsed JSON to record lists; reload the
# files before re-running this cell.
news_mrc_train_1 = make_instance(news_mrc_train_1['data'])
news_mrc_train_2 = make_instance(news_mrc_train_2['data'])
news_mrc_train_3 = make_instance(news_mrc_train_3['data'])
news_mrc_train_4 = no_answer_make_instance(news_mrc_train_4['data'])
news_mrc_dev_1 = make_instance(news_mrc_dev_1['data'])
news_mrc_dev_2 = make_instance(news_mrc_dev_2['data'])
news_mrc_dev_3 = make_instance(news_mrc_dev_3['data'])
news_mrc_dev_4 = no_answer_make_instance(news_mrc_dev_4['data'])

100%|████████████████████████████████████████████████████████████████████████| 24009/24009 [00:00<00:00, 419528.25it/s]
100%|██████████████████████████████████████████████████████████████████████| 111967/111967 [00:00<00:00, 385642.31it/s]
100%|████████████████████████████████████████████████████████████████████████| 15992/15992 [00:00<00:00, 420574.41it/s]
100%|██████████████████████████████████████████████████████████████████████████| 8000/8000 [00:00<00:00, 424465.62it/s]
100%|████████████████████████████████████████████████████████████████████████| 13997/13997 [00:00<00:00, 460661.89it/s]
100%|██████████████████████████████████████████████████████████████████████████| 1999/1999 [00:00<00:00, 294355.21it/s]
100%|██████████████████████████████████████████████████████████████████████████| 3001/3001 [00:00<00:00, 286064.10it/s]
100%|██████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 500036.24it/s]


In [31]:
# Hold out 10% of each answerable training set as a third split (split is over question records).
news_mrc_train_1, news_mrc_test_1 = train_test_split(news_mrc_train_1,test_size = 0.1,random_state=42, shuffle=True)
news_mrc_train_2, news_mrc_test_2 = train_test_split(news_mrc_train_2,test_size = 0.1,random_state=42, shuffle=True)
news_mrc_train_3, news_mrc_test_3 = train_test_split(news_mrc_train_3,test_size = 0.1,random_state=42, shuffle=True)

In [32]:
# Same 10% hold-out for the unanswerable set.
news_mrc_train_4, news_mrc_test_4 = train_test_split(news_mrc_train_4,test_size = 0.1,random_state=42, shuffle=True)

In [33]:
# Merge the three answerable sets; the unanswerable set 4 is kept separate.
news_train = news_mrc_train_1+news_mrc_train_2+news_mrc_train_3
news_dev = news_mrc_dev_1+news_mrc_dev_2+news_mrc_dev_3
news_test = news_mrc_test_1+news_mrc_test_2+news_mrc_test_3

In [34]:
# Export answerable news records in DPR format.
# Split naming: the hold-out from train is written as 'dev', the validation files as 'test'.
path = '../../data/dpr/news'
os.makedirs(path, exist_ok=True)
for name,data in zip(['train','dev','test'],[news_train, news_test, news_dev]):
    dpr = [
        dict(question=i['question'], answer = i['answer'], positive_ctxs=[dict(title=i['title'], context=' '.join(i['context'].split()))])
        for i in data
        if i['label']==1
    ]
    save_jsonl(path,dpr,name)

100%|███████████████████████████████████████████████████████████████████████| 273543/273543 [00:04<00:00, 62530.06it/s]
100%|█████████████████████████████████████████████████████████████████████████| 30395/30395 [00:00<00:00, 58180.26it/s]
100%|█████████████████████████████████████████████████████████████████████████| 37996/37996 [00:00<00:00, 57568.94it/s]


In [35]:
# Reranker data = merged answerable sets plus the label-0 unanswerable set 4.
# annotation is the flattening version defined in the book section, so this is a plain concatenation.
train_data = annotation([news_train, news_mrc_train_4])
dev_data = annotation([news_test, news_mrc_test_4])
test_data = annotation([news_dev, news_mrc_dev_4])

In [36]:
# Write the reranker splits for the news dataset.
for split_name, split_rows in (('train', train_data), ('dev', dev_data), ('test', test_data)):
    save_jsonl('../../data/rerank/v2/news', split_rows, split_name)

100%|███████████████████████████████████████████████████████████████████████| 287943/287943 [00:04<00:00, 59167.89it/s]
100%|█████████████████████████████████████████████████████████████████████████| 31995/31995 [00:00<00:00, 61015.53it/s]
100%|█████████████████████████████████████████████████████████████████████████| 39996/39996 [00:00<00:00, 55213.48it/s]


# KLUE MRC

In [37]:
def klue_make_instance(data:List):
    """Convert KLUE-MRC rows into the record layout used by the other datasets.

    Args:
        data: iterable of rows with keys ``title``, ``context``, ``question``,
            ``answers`` (a mapping with a ``text`` entry) and ``is_impossible``.

    Returns:
        A list of dicts with keys ``title``, ``context``, ``question``,
        ``answer`` (the row's ``answers['text']``) and ``label``, which is 1
        when ``is_impossible`` equals ``False`` and 0 otherwise.
    """
    return [
        dict(
            title=row['title'],
            context=row['context'],
            question=row['question'],
            answer=row['answers']['text'],
            label=1 if row['is_impossible']==False else 0,
        )
        for row in tqdm(data)
    ]

In [38]:
from datasets import load_dataset
dataset = load_dataset('klue', 'mrc')	# load the 'mrc' configuration of the KLUE dataset

Reusing dataset klue (C:\Users\User\.cache\huggingface\datasets\klue\mrc\1.0.0\e0fc3bc3de3eb03be2c92d72fd04a60ecc71903f821619cb28ca0e1e29e4233e)


  0%|          | 0/2 [00:00<?, ?it/s]

In [39]:
# Convert the 'train' and 'validation' splits into the shared record layout.
klue_mrc_train = klue_make_instance(dataset['train'])
klue_mrc_dev = klue_make_instance(dataset['validation'])

100%|█████████████████████████████████████████████████████████████████████████| 17554/17554 [00:01<00:00, 10407.20it/s]
100%|███████████████████████████████████████████████████████████████████████████| 5841/5841 [00:00<00:00, 10889.82it/s]


In [40]:
# Hold out 20% of the train records as a third split (the other datasets use 5-10%).
klue_mrc_train, klue_mrc_test = train_test_split(klue_mrc_train,test_size = 0.2, random_state=42, shuffle=True)

In [41]:
# Export answerable KLUE records in DPR format.
# Split naming: the hold-out from train is written as 'dev', the validation split as 'test'.
path = '../../data/dpr/klue'
os.makedirs(path, exist_ok=True)
for name,data in zip(['train','dev','test'],[klue_mrc_train, klue_mrc_test, klue_mrc_dev]):
    dpr = [
        dict(question=i['question'], answer = i['answer'], positive_ctxs=[dict(title=i['title'], context=' '.join(i['context'].split()))])
        for i in data
        if i['label']==1
    ]
    save_jsonl(path,dpr,name)

100%|███████████████████████████████████████████████████████████████████████████| 9654/9654 [00:00<00:00, 55254.33it/s]
100%|███████████████████████████████████████████████████████████████████████████| 2383/2383 [00:00<00:00, 56639.63it/s]
100%|███████████████████████████████████████████████████████████████████████████| 4008/4008 [00:00<00:00, 58849.29it/s]


In [42]:
# annotation is the flattening version defined in the book section, so each
# call just unwraps the single list it is given.
# Split naming: the hold-out from train becomes dev_data, the validation split test_data.
train_data = annotation([klue_mrc_train])
dev_data = annotation([klue_mrc_test])
test_data = annotation([klue_mrc_dev])

In [44]:
# Write the reranker splits for the KLUE dataset.
for split_name, split_rows in (('train', train_data), ('dev', dev_data), ('test', test_data)):
    save_jsonl('../../data/rerank/v2/klue', split_rows, split_name)

100%|█████████████████████████████████████████████████████████████████████████| 14043/14043 [00:00<00:00, 54582.23it/s]
100%|███████████████████████████████████████████████████████████████████████████| 3511/3511 [00:00<00:00, 60210.16it/s]
100%|███████████████████████████████████████████████████████████████████████████| 5841/5841 [00:00<00:00, 60267.97it/s]


# Total data eda

In [45]:
# Peek at one record. train_data was last assigned in the KLUE section above,
# so this shows a KLUE training record.
train_data[9]

{'title': '호라티우스',
 'context': '호라티우스의 출신 가문은 정확히 알려져 있지 않다. 아마 그의 아버지는 노예에서 해방된 자유신분(libertinus)으로서 로마 자유시민권을 가진 여인과 결혼한 것으로 보인다. 호라티우스는 어린시절부터 아버지로부터 세심한 교육을 받았으며, 기원전 45년에 당시 문화와 예술의 중심지인 고대 그리스의 아테네에 유학하여 고대 그리스 철학과 문학을 공부한다. 이 시기에 그는 역시 고대 그리스 문화를 사랑하는 마르쿠스 브루투스와 친교를 맺게되어 그를 따라 소아시아 지방에서 여러 전투에 참가한다. 기원전 약 40년을 전후로 호라티우스는 로마로 돌아와 젊은 문학자와 사귀면서, 특히 베르길리우스의 주선으로 당시의 로마의 문학 애호가이자 부호인 가이우스 마에케나스(Gaius Maecenas)에게 소개된다. 이 만남은 호라티우스가 사망할 때까지 깊은 우정관계로 발전한다. 특히 마이케나스는 호라티우스에게 기원전 32년 사비나 농장을 선물함으로써, 여기서 호라티우스는 경제적 어려움에서 완전히 해방되어 시 창작에 열중하게 된다.',
 'question': '호라티우스에게 베르길리우스를 소개해준 사람은?',
 'answer': ['가이우스 마에케나스(Gaius Maecenas)', 'Gaius Maecenas'],
 'label': 0}

In [46]:
# Reload every reranker file written above and pool all records into total_data.
path = '../../data/rerank/v2'
# NOTE(review): question_pool and passage_pool are created here but not filled
# in this cell (the only code touching question_pool below is commented out).
question_pool = defaultdict(list)
passage_pool = defaultdict(list)
total_data = []
for name in tqdm(['korquad','klue','book','ad','mrc','news']):
    cur_path = os.path.join(path, name)
    print(name)
    for i in ['train','dev','test']:
        # Files are named <split>.jsonl inside each dataset directory.
        final_path = os.path.join(cur_path, i)+'.jsonl'
        print(name + ' ' + i)
        data = load_jsonl(final_path)
        # for j in data:
        #     question_pool[j['question']].append(j)
        total_data+=data    

  0%|                                                                                            | 0/6 [00:00<?, ?it/s]

korquad
korquad train



0it [00:00, ?it/s][A
10738it [00:00, 106495.50it/s][A
22118it [00:00, 108833.65it/s][A
33758it [00:00, 111932.71it/s][A
54366it [00:00, 113097.45it/s][A


korquad dev



6041it [00:00, 115745.24it/s]


korquad test



5774it [00:00, 141056.04it/s]
 17%|██████████████                                                                      | 1/6 [00:00<00:03,  1.61it/s]

klue
klue train



0it [00:00, ?it/s][A
6295it [00:00, 58180.03it/s][A
14043it [00:00, 62953.76it/s][A


klue dev



3511it [00:00, 69369.96it/s]


klue test



5841it [00:00, 69679.01it/s]
 33%|████████████████████████████                                                        | 2/6 [00:01<00:01,  2.09it/s]

book
book train



0it [00:00, ?it/s][A
10928it [00:00, 108954.36it/s][A
22424it [00:00, 110315.77it/s][A
33563it [00:00, 110311.01it/s][A
44594it [00:00, 109781.30it/s][A
55746it [00:00, 110070.27it/s][A
67814it [00:00, 111196.35it/s][A
79111it [00:00, 111256.60it/s][A
90235it [00:00, 110748.04it/s][A
101424it [00:00, 110886.95it/s][A
113119it [00:01, 109877.57it/s][A
124645it [00:01, 110768.40it/s][A
136293it [00:01, 111416.32it/s][A
147606it [00:01, 111704.57it/s][A
158779it [00:01, 109946.50it/s][A
169779it [00:04, 12027.76it/s] [A
180072it [00:04, 16024.50it/s][A
191121it [00:04, 21606.00it/s][A
202334it [00:04, 28652.47it/s][A
212193it [00:04, 35564.45it/s][A
223405it [00:04, 44876.88it/s][A
234384it [00:04, 54543.89it/s][A
245060it [00:05, 63736.39it/s][A
256319it [00:05, 73484.68it/s][A
267578it [00:05, 82148.70it/s][A
278667it [00:05, 88970.62it/s][A
289987it [00:05, 94457.85it/s][A
301448it [00:05, 99346.66it/s][A
312680it [00:05, 102688.47it/s][A
325056it [00:05, 

book dev



0it [00:00, ?it/s][A
12172it [00:00, 120225.01it/s][A
24195it [00:00, 99313.80it/s] [A
34347it [00:00, 98049.45it/s][A
45000it [00:00, 100738.84it/s][A


book test



0it [00:00, ?it/s][A
12318it [00:00, 118006.65it/s][A
24119it [00:00, 116981.14it/s][A
36322it [00:00, 119127.13it/s][A
50000it [00:00, 120899.79it/s][A
 50%|██████████████████████████████████████████                                          | 3/6 [00:15<00:20,  6.77s/it]

ad
ad train



0it [00:00, ?it/s][A
10279it [00:00, 101882.84it/s][A
20468it [00:00, 101076.53it/s][A
30576it [00:00, 99811.18it/s] [A
40559it [00:00, 94174.36it/s][A
50017it [00:00, 87953.13it/s][A
58884it [00:00, 85061.04it/s][A
67924it [00:00, 86357.73it/s][A
77266it [00:00, 88328.41it/s][A
86803it [00:00, 90187.55it/s][A
96290it [00:01, 91364.32it/s][A
105855it [00:01, 92353.64it/s][A
115318it [00:01, 92768.35it/s][A
124609it [00:01, 92262.17it/s][A
133845it [00:01, 90853.35it/s][A
142941it [00:01, 89790.09it/s][A
151928it [00:01, 86998.21it/s][A
160747it [00:01, 86833.19it/s][A
170592it [00:01, 88906.41it/s][A
179899it [00:01, 89943.25it/s][A
189257it [00:02, 90614.86it/s][A
202151it [00:02, 90108.93it/s][A


ad dev



0it [00:00, ?it/s][A
9565it [00:00, 93964.06it/s][A
22462it [00:00, 93346.02it/s][A


ad test



0it [00:00, ?it/s][A
9336it [00:00, 92904.46it/s][A
28076it [00:00, 92503.65it/s][A
 67%|████████████████████████████████████████████████████████                            | 4/6 [00:18<00:10,  5.21s/it]

mrc
mrc train



0it [00:00, ?it/s][A
8293it [00:00, 80269.18it/s][A
16485it [00:00, 78115.73it/s][A
24300it [00:00, 77865.79it/s][A
32088it [00:00, 77561.02it/s][A
39845it [00:00, 77179.95it/s][A
47619it [00:00, 77146.09it/s][A
55579it [00:00, 77575.87it/s][A
63540it [00:00, 77954.67it/s][A
71424it [00:00, 77978.87it/s][A
79346it [00:01, 78207.06it/s][A
87326it [00:01, 78397.36it/s][A
95166it [00:01, 78152.97it/s][A
102982it [00:01, 77855.34it/s][A
110768it [00:01, 76769.37it/s][A
118448it [00:01, 70141.05it/s][A
125569it [00:01, 66810.33it/s][A
132928it [00:01, 68622.84it/s][A
140666it [00:01, 70798.72it/s][A
148670it [00:01, 73304.10it/s][A
156639it [00:02, 74970.45it/s][A
164480it [00:02, 75638.49it/s][A
172218it [00:02, 75730.75it/s][A
179984it [00:02, 75852.23it/s][A
187586it [00:02, 75759.92it/s][A
195174it [00:02, 75270.18it/s][A
202710it [00:02, 68454.25it/s][A
210802it [00:02, 71654.41it/s][A
219654it [00:02, 76168.84it/s][A
227821it [00:03, 77573.96it/s][A
2359

mrc dev



0it [00:00, ?it/s][A
7146it [00:00, 71050.97it/s][A
14350it [00:00, 70108.63it/s][A
22167it [00:00, 71060.47it/s][A
30125it [00:00, 73518.91it/s][A
39631it [00:00, 75564.03it/s][A


mrc test



0it [00:00, ?it/s][A
7186it [00:00, 71649.10it/s][A
14547it [00:00, 72691.35it/s][A
22026it [00:00, 73156.54it/s][A
29379it [00:00, 73057.77it/s][A
44035it [00:00, 76350.20it/s][A
 83%|██████████████████████████████████████████████████████████████████████              | 5/6 [00:27<00:06,  6.62s/it]

news
news train



0it [00:00, ?it/s][A
6486it [00:00, 61949.36it/s][A
12750it [00:00, 62096.76it/s][A
18960it [00:00, 60263.96it/s][A
24995it [00:00, 60169.39it/s][A
31901it [00:00, 62691.65it/s][A
38576it [00:00, 62256.81it/s][A
45370it [00:00, 63912.07it/s][A
53099it [00:00, 67968.89it/s][A
60867it [00:00, 70913.27it/s][A
68648it [00:01, 72801.60it/s][A
76586it [00:01, 74646.65it/s][A
84701it [00:01, 75307.33it/s][A
92256it [00:01, 75333.06it/s][A
99792it [00:01, 74850.07it/s][A
107279it [00:01, 74405.55it/s][A
114878it [00:01, 74746.25it/s][A
122756it [00:01, 75229.68it/s][A
130280it [00:01, 75098.84it/s][A
138451it [00:01, 75755.41it/s][A
146488it [00:02, 76652.59it/s][A
154373it [00:02, 77015.34it/s][A
162236it [00:02, 77115.35it/s][A
169948it [00:02, 73500.85it/s][A
177330it [00:02, 66294.29it/s][A
184098it [00:02, 63551.91it/s][A
191415it [00:02, 65861.95it/s][A
198470it [00:02, 66940.87it/s][A
205806it [00:02, 68471.10it/s][A
213709it [00:03, 71193.21it/s][A
221412

news dev



0it [00:00, ?it/s][A
6762it [00:00, 65278.37it/s][A
14427it [00:00, 71575.26it/s][A
22275it [00:00, 73431.23it/s][A
31995it [00:00, 72261.07it/s][A


news test



0it [00:00, ?it/s][A
7937it [00:00, 77084.00it/s][A
15646it [00:00, 76517.05it/s][A
23298it [00:00, 65350.74it/s][A
29991it [00:00, 58427.58it/s][A
39996it [00:00, 65216.71it/s][A
100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:32<00:00,  5.40s/it]


In [48]:
import hashlib
def compute_hash(text):
    """Return the MD5 digest of ``text`` as a hexadecimal string.

    The text is converted to bytes with ``str.encode()`` using its default
    encoding before hashing.
    """
    hasher = hashlib.md5()
    hasher.update(text.encode())
    return hasher.hexdigest()

2092531


In [58]:
# Collapse repeated passages: every (title, context) pair is keyed by the hash
# of "title context", so identical passages map to a single entry.
total_docs = {}
for example in tqdm(total_data):
    # Hash once and reuse it (avoids a throwaway `id` variable that shadowed the builtin).
    doc_key = compute_hash(example['title']+' '+example['context'])
    if doc_key not in total_docs:
        # doc_id is the 1-based first-seen position of the passage. Dicts keep
        # insertion order, so this needs no counter from a later cell and the
        # cell runs on a fresh kernel.
        total_docs[doc_key] = dict(
            title=example['title'],
            context=example['context'],
            doc_id=len(total_docs)+1,
        )

100%|█████████████████████████████████████████████████████████████████████| 2092531/2092531 [00:21<00:00, 96691.78it/s]


In [59]:
# Re-key the documents by a 1-based running integer instead of their hash key.
# Each stored value is the same dict object held by total_docs, so writing
# 'doc_id' here updates that dict in total_docs as well.
total_docs_ = {}
c = 0  # stays 0 when total_docs is empty; otherwise ends as the last assigned id
for c, doc in enumerate(total_docs.values(), start=1):
    doc['doc_id'] = c
    total_docs_[c] = doc

In [61]:
# Display the record stored under integer key 1 as a spot check.
total_docs_[1]

{'title': '회색늑대',
 'context': '북유럽 신화와 일본 신화에서 늑대는 신으로 가깝게 묘사되었다. 일본에서는 농부가 신사에서 늑대를 숭배하고 굴 근처에서 먹이를 주며, 야생 맷돼지와 사슴으로부터 작물을 보호하기 위해서 늑대에게 간청한다는 이야기가 있고, 북유럽 신화에서 나오는 펜리르 늑대는 로키의 아들로 묘사되었다. 다른 문화에서, 아일랜드 신화에서는 코르막 맥 에어트가 늑대로 묘사되며, 로마 신화에서는 암늑대가 로물루스와 레무스를 기르며 로마를 건국했다는 등 여러 기초 신화에서 늑대를 중심 인물로 두고 있다. 튀르크 신화,, 몽골 신화, 아이누 신화에서는 늑대가 자기 민족의 조상으로 나온 반면, 데나이나 민족은 늑대가 한 남성의 형태이며 자신의 형제인 것으로 믿는다. 늑대는 고대 그리스와 고대 로마 신화에서 아폴론 신과 늑대가 태양에 관련되어 있다고 설명하며 북유럽 신화에서는 태양의 신인 스콜이 늑대와 관련되는 등 일부 유라시아의 신화에서 늑대와 태양을 관련시켰다. 파웨니 민족의 창조 설화에 따르면 늑대는 죽음을 처음으로 겪은 동물이다. 북유럽 신화에서 볼바 마녀가 히로킨 거인과 힐다 거인을 늑대로 다스리는 것으로 묘사하고, 나바호 족은 늑대를 마녀로 생각하고 두려워하는 등 가끔 북유럽과 아메리카의 일부 원주민 신화에서 늑대는 모두 마법과 관련된 것으로 묘사된다. 마찬가지로, 칠코틴 족은 늑대가 정신 질환과 사망을 일으킬 수 있다고 생각했다.',
 'doc_id': 1}

In [64]:
# Number of entries in the integer-keyed document table.
len(total_docs_)

605688

In [None]:
# Group every example by its question text: question -> list of
# {title, context, answer, label} records collected across all of data_list.
question_pool = defaultdict(list)
for split in data_list:
    for example in split:
        question_text = example['question']
        record = {
            'title': example['title'],
            'context': example['context'],
            'answer': example['answer'],
            'label': example['label'],
        }
        question_pool[question_text].append(record)

In [None]:
# Number of distinct question strings collected in question_pool.
print(len(question_pool))

2004545


In [None]:
# Min / median / 95th / 99th / max of the number of records stored per question.
records_per_question = [len(records) for records in question_pool.values()]
print(np.percentile(records_per_question, [0, 50, 95, 99, 100]))

[ 1.  1.  1.  1. 51.]


In [None]:
# Keep only the questions that have at least one record with label 0; the value
# is the question's full record list (the same list object as in question_pool).
no_answer = {
    question_text: records
    for question_text, records in question_pool.items()
    if any(record['label'] == 0 for record in records)
}
# Min / median / 95th / 99th / max of the number of records per retained question.
records_per_no_answer = [len(records) for records in no_answer.values()]
print(np.percentile(records_per_no_answer, [0, 50, 95, 99, 100]))

[ 1.  1.  1.  1. 29.]


In [57]:
# Number of questions that have at least one label-0 record.
print(len(no_answer))

408411


In [56]:
# Fraction of the no_answer questions that have two or more records.
no_answer_record_counts = np.array([len(records) for records in no_answer.values()])
sum(no_answer_record_counts >= 2) / len(no_answer)

0.003045951259882814

In [None]:
def _format_passage(example):
    """Return "title context" with the context's whitespace runs collapsed to single spaces."""
    return example['title'] + ' ' + ' '.join(example['context'].split())


def _answer_texts(answer):
    """Return the answer as a list of strings: None -> [], a bare string -> [string]."""
    if answer is None:
        return []
    if isinstance(answer, str):
        return [answer]
    return list(answer)


def annotation(data_list, max_tries=1000): # process the train, val and test splits separately.
    """Build labelled (question, context) pairs with one random negative per positive.

    Args:
        data_list: iterable of splits; each split is an iterable of dicts with
            the keys 'title', 'context', 'question', 'answer' and 'label'.
        max_tries: maximum number of random draws used to find a negative
            passage for one positive example before giving up on it.

    Returns:
        A list of dicts with keys 'question', 'answer', 'context', 'label':
          * label 1 - the example's own passage (source label == 1),
          * label 0 - a randomly drawn passage that is not the example's own
            passage and contains none of its answer strings,
          * label 2 - the example's own passage when the source label != 1.
    """
    output = []
    label_1 = 0
    label_0 = 0
    label_0_0 = 0
    skipped = 0

    # Pool of candidate negatives. Every passage is formatted exactly like the
    # positives ("title" + space + normalized context) so that formatting
    # cannot give away which passages were sampled.
    total_contexts = set()
    for i in tqdm(data_list,desc='make_context'):
        for j in i:
            total_contexts.add(_format_passage(j))
    # Set iteration order depends on string hash randomization; sorting makes
    # random.choice reproducible for a fixed random seed.
    total_contexts = sorted(total_contexts)

    for i in data_list:
        for j in tqdm(i, desc='attach'):
            q = j['question']
            a = j['answer']
            passage = _format_passage(j)
            if j['label']==1:
                output.append(dict(question=q, answer=a, context=passage, label=1))
                label_1+=1
                # A non-list answer may be a bare string or None (is_impossible style).
                answers = _answer_texts(a)
                # Bounded rejection sampling: an unbounded loop would spin forever
                # when no candidate qualifies (e.g. an empty-string answer, or a
                # split whose only passage is the positive itself).
                for _ in range(max_tries):
                    tmp = random.choice(total_contexts)
                    if tmp != passage and not any(answer in tmp for answer in answers):
                        label_0+=1
                        output.append(dict(question=q, answer=a, context=tmp, label=0))
                        break
                else:
                    skipped+=1
            else:
                label_0_0 +=1
                output.append(dict(question=q, answer=a, context=passage, label=2))
    print(f'positive 개수 : {label_1}, negative 개수 : {label_0}, is_impossible neagtive 개수 : {label_0_0}, total negative 개수 : {label_0+label_0_0}')
    if skipped:
        print(f'negative sampling skipped for {skipped} positive example(s) after {max_tries} tries each')

    return output