In [12]:
import pandas as pd
import os
import json
import pickle
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
from typing import List
import re
from sklearn.model_selection import train_test_split
from utils.utils import *
import random
from collections import Counter, defaultdict

# seed_everything comes from the `utils.utils` star import; presumably it seeds the
# Python `random` / NumPy (and possibly torch) RNGs — verify in utils/utils.py.
# `annotation` below draws negatives with random.choice, so the seed matters there.
seed_everything(42)

# KorQuAD

In [13]:
def make_instance(data:List):
    """Flatten SQuAD-style articles into one record per question.

    Each article must carry a ``title`` (or, failing that, a ``doc_title``)
    and a ``paragraphs`` list; each paragraph has a ``context`` and a ``qas``
    list. Underscores in the title are replaced by spaces.

    Args:
        data: list of article dicts (the ``data`` field of a SQuAD-format JSON).

    Returns:
        List of dicts with keys ``title``, ``context``, ``question``,
        ``answer`` and ``label``. ``answer`` is the list of answer texts when
        ``qas['answers']`` is a list, otherwise ``qas['answers']['text']``.
        ``label`` is 0 when ``is_impossible`` is present and equals True,
        and 1 otherwise.

    Raises:
        Exception: if an article has neither ``title`` nor ``doc_title``.
    """
    records = []
    for article in tqdm(data):
        raw_title = article.get('title')
        if raw_title is None:
            raw_title = article.get('doc_title')
        if raw_title is None:
            raise Exception('title error')
        title = re.sub('_', ' ', raw_title)

        for paragraph in article['paragraphs']:
            context = paragraph['context']
            for qa in paragraph['qas']:
                question = qa['question']
                answers = qa['answers']
                if type(answers) == list:
                    answer = [candidate['text'] for candidate in answers]
                else:
                    answer = answers['text']
                # A missing flag counts as answerable; only a value equal to True marks it impossible.
                label = 0 if qa.get('is_impossible') == True else 1
                records.append(dict(title=title, context=context, question=question, answer=answer, label=label))
    return records

In [14]:
def no_answer_make_instance(data:List):
    """Flatten SQuAD-style articles from an unanswerable-only split.

    Title handling matches ``make_instance``: ``title`` is used, falling back
    to ``doc_title``, and underscores are replaced by spaces.

    Args:
        data: list of article dicts (the ``data`` field of a SQuAD-format JSON).

    Returns:
        List of dicts with keys ``title``, ``context``, ``question``,
        ``answer`` and ``label``. ``answer`` is ``qas['answers']`` passed
        through unchanged (``None`` when the key is absent). Every record
        gets ``label=0``.

    Raises:
        Exception: if an article has neither ``title`` nor ``doc_title``.
    """
    records = []
    for article in tqdm(data):
        raw_title = article.get('title')
        if raw_title is None:
            raw_title = article.get('doc_title')
        if raw_title is None:
            raise Exception('title error')
        title = re.sub('_', ' ', raw_title)

        for paragraph in article['paragraphs']:
            context = paragraph['context']
            records.extend(
                dict(title=title, context=context, question=qa['question'], answer=qa.get('answers'), label=0)
                for qa in paragraph['qas']
            )
    return records

In [15]:
def annotation(data_list, max_tries=1000):
    """Build (question, context) pairs with randomly sampled negative contexts.

    Run this separately for train / val / test so negatives never cross split
    boundaries.

    Every context string is built the same way: ``title + ' ' + context``, with
    the context's whitespace collapsed to single spaces. This applies to the
    positives, the sampled negatives and the label-2 records alike.

    For each record with ``label == 1`` the function emits:
      * the gold context with ``label=1``;
      * one context drawn at random from all contexts in ``data_list`` that
        contains none of the record's answer strings, with ``label=0``.
    Every other record (unanswerable question) is emitted once with its own
    context and ``label=2``.

    Args:
        data_list: list of record lists as produced by ``make_instance`` /
            ``no_answer_make_instance`` (keys ``title``, ``context``,
            ``question``, ``answer``, ``label``). It is iterated twice, so it
            must be re-iterable (not a generator).
        max_tries: maximum number of random draws when looking for a negative
            context. If none is found, that question gets no negative instead
            of looping forever.

    Returns:
        List of dicts with keys ``question``, ``answer``, ``context``, ``label``.
    """
    def _passage(record):
        # Single normalisation shared by the sampling pool, positives and label-2 records.
        return record['title'] + ' ' + ' '.join(record['context'].split())

    output = []
    label_1 = 0
    label_0 = 0
    label_0_0 = 0
    skipped = 0

    unique_contexts = set()
    for i in tqdm(data_list, desc='make_context'):
        for j in i:
            unique_contexts.add(_passage(j))
    # Sorted so random.choice is reproducible under a fixed seed; set iteration
    # order depends on per-process string hash randomisation.
    total_contexts = sorted(unique_contexts)

    for i in data_list:
        for j in tqdm(i, desc='attach'):
            q = j['question']
            a = j['answer']
            if j['label'] == 1:
                pos = _passage(j)
                output.append(dict(question=q, answer=a, context=pos, label=1))
                label_1 += 1

                # Collect the answer strings a negative must avoid. Empty strings
                # are dropped because '' is "in" every context.
                if isinstance(a, list):
                    answers = [text for text in a if text]
                elif a:
                    answers = [a]
                else:
                    answers = []

                for _ in range(max_tries):
                    tmp = random.choice(total_contexts)
                    if not any(text in tmp for text in answers):
                        label_0 += 1
                        output.append(dict(question=q, answer=a, context=tmp, label=0))
                        break
                else:
                    # Every draw contained an answer (e.g. the pool holds only this context).
                    skipped += 1
            else:
                label_0_0 += 1
                output.append(dict(question=q, answer=a, context=_passage(j), label=2))
    print(f'positive 개수 : {label_1}, negative 개수 : {label_0}, is_impossible neagtive 개수 : {label_0_0}, total negative 개수 : {label_0+label_0_0}')
    if skipped:
        print(f'skipped negatives for {skipped} questions: no answer-free context found in {max_tries} draws')

    return output

In [16]:
# Read whole files via Path.read_text so handles are closed immediately (a bare
# json.load(open(...)) leaves closing to the garbage collector). JSON is UTF-8, so
# pass the encoding explicitly instead of relying on the platform default.
korquad_dir = Path('../../data/MRC_data/KorQuAD 1.0')
korquad_train = json.loads((korquad_dir / 'KorQuAD_v1.0_train.json').read_text(encoding='utf-8'))
korquad_dev = json.loads((korquad_dir / 'KorQuAD_v1.0_dev.json').read_text(encoding='utf-8'))

In [17]:
# Replace each raw SQuAD-format dict with its flattened record list.
# NOTE: this rebinds the raw-JSON names, so re-running this cell on its own fails.
korquad_train, korquad_dev = [make_instance(raw['data']) for raw in (korquad_train, korquad_dev)]

100%|███████████████████████████████████████████████████████████████████████████| 1420/1420 [00:00<00:00, 23329.71it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 140/140 [00:00<00:00, 7004.85it/s]


In [18]:
def print_statistics(data, name='data'):
    """Print ``name`` and return the label distribution of ``data``.

    Args:
        data: iterable of dicts, each with a ``label`` key.
        name: heading printed before the progress bar.

    Returns:
        collections.Counter mapping each label to its number of records.
    """
    print(name)
    # Counter is already imported in the top import cell; no need to re-import here.
    return Counter(record['label'] for record in tqdm(data))

In [19]:
print_statistics(data=korquad_train + korquad_dev, name='korquad')  # label distribution of train+dev

korquad


100%|██████████████████████████████████████████████████████████████████████| 66181/66181 [00:00<00:00, 12321698.91it/s]


Counter({1: 66181})

In [20]:
print(f'{len(korquad_train)}')  # number of flattened KorQuAD train questions

60407


In [21]:
print(f'{len(korquad_dev)}')  # number of flattened KorQuAD dev questions

5774


# 기계독해 - dev split 없음 (train만 제공)

In [22]:
# Read whole files via Path.read_text so handles are closed immediately.
mrc_dir = Path('../../data/MRC_data/기계독해/기계독해분야')
mrc_train_1 = json.loads((mrc_dir / '01.Normal/ko_nia_normal_squad_all.json').read_text(encoding='utf-8'))
mrc_train_2 = json.loads((mrc_dir / '02.NoAnswer/ko_nia_noanswer_squad_all.json').read_text(encoding='utf-8'))  # no answer
mrc_train_3 = json.loads((mrc_dir / '03.Clue_/ko_nia_clue0529_squad_all.json').read_text(encoding='utf-8'))

In [23]:
# Parse each raw split in place (rebinds the raw-JSON names; re-running this cell alone fails).
# The NoAnswer split needs the no-answer parser.
mrc_train_1, mrc_train_2, mrc_train_3 = [
    parse(raw['data'])
    for parse, raw in (
        (make_instance, mrc_train_1),
        (no_answer_make_instance, mrc_train_2),
        (make_instance, mrc_train_3),
    )
]

100%|█████████████████████████████████████████████████████████████████████████| 47314/47314 [00:00<00:00, 50737.58it/s]
100%|████████████████████████████████████████████████████████████████████████| 20030/20030 [00:00<00:00, 329391.30it/s]
100%|████████████████████████████████████████████████████████████████████████| 34500/34500 [00:00<00:00, 238213.08it/s]


In [24]:
print_statistics(data=[*mrc_train_1, *mrc_train_2, *mrc_train_3], name='mrc')

mrc


100%|█████████████████████████████████████████████████████████████████████| 440332/440332 [00:00<00:00, 3628560.06it/s]


Counter({1: 340088, 0: 100244})

In [25]:
mrc = [*mrc_train_1, *mrc_train_2, *mrc_train_3]  # all MRC splits (no dev split exists)

# 도서자료 기계독해 - dev 존재

In [26]:
# Read whole files via Path.read_text so handles are closed immediately.
book_dir = Path('../../data/MRC_data/도서자료 기계독해')
book_mrc_train = json.loads((book_dir / 'Training/도서_train/도서_220419_add/도서_220419_add.json').read_text(encoding='utf-8'))
book_mrc_dev = json.loads((book_dir / 'Validation/도서_valid/도서.json').read_text(encoding='utf-8'))

In [27]:
# Parse both raw splits in place (rebinds the raw-JSON names; re-running this cell alone fails).
book_mrc_train, book_mrc_dev = [make_instance(raw['data']) for raw in (book_mrc_train, book_mrc_dev)]

100%|████████████████████████████████████████████████████████████████████████████| 5368/5368 [00:02<00:00, 2011.73it/s]
100%|███████████████████████████████████████████████████████████████████████████| 1994/1994 [00:00<00:00, 40540.00it/s]


In [28]:
book = [*book_mrc_train, *book_mrc_dev]  # train + dev records

In [29]:
print_statistics(data=[*book_mrc_train, *book_mrc_dev], name='book')

book


100%|█████████████████████████████████████████████████████████████████████| 950000/950000 [00:00<00:00, 4716366.00it/s]


Counter({1: 665000, 0: 285000})

# 행정문서 기계독해

In [30]:
# Raw administrative-document MRC splits, read via Path.read_text so handles are closed.
# Tasks are listed in a fixed order so the numbered names line up:
#   1 multiple_choice, 2 span_extraction, 3 span_extraction_how,
#   4 unanswerable (no-answer questions only), 5 text_entailment.
ad_root = Path('../../data/MRC_data/행정 문서 대상 기계독해/01.데이터')
ad_tasks = ('multiple_choice', 'span_extraction', 'span_extraction_how', 'unanswerable', 'text_entailment')

ad_train_1, ad_train_2, ad_train_3, ad_train_4, ad_train_5 = [
    json.loads((ad_root / '1.Training/라벨링데이터' / f'TL_{task}' / f'TL_{task}.json').read_text(encoding='utf-8'))
    for task in ad_tasks
]
ad_dev_1, ad_dev_2, ad_dev_3, ad_dev_4, ad_dev_5 = [
    json.loads((ad_root / '2.Validation/라벨링데이터' / f'VL_{task}' / f'VL_{task}.json').read_text(encoding='utf-8'))
    for task in ad_tasks
]

In [31]:
# Parse every raw split in place (rebinds the raw-JSON names; re-running this cell alone fails).
# The unanswerable splits (*_4) need the no-answer parser.
(
    ad_train_1, ad_train_2, ad_train_3, ad_train_4, ad_train_5,
    ad_dev_1, ad_dev_2, ad_dev_3, ad_dev_4, ad_dev_5,
) = [
    parse(raw['data'])
    for parse, raw in (
        (make_instance, ad_train_1),
        (make_instance, ad_train_2),
        (make_instance, ad_train_3),
        (no_answer_make_instance, ad_train_4),
        (make_instance, ad_train_5),
        (make_instance, ad_dev_1),
        (make_instance, ad_dev_2),
        (make_instance, ad_dev_3),
        (no_answer_make_instance, ad_dev_4),
        (make_instance, ad_dev_5),
    )
]

100%|████████████████████████████████████████████████████████████████████████| 15085/15085 [00:00<00:00, 293293.70it/s]
100%|████████████████████████████████████████████████████████████████████████| 63932/63932 [00:00<00:00, 459059.97it/s]
100%|████████████████████████████████████████████████████████████████████████| 29074/29074 [00:00<00:00, 511860.76it/s]
100%|██████████████████████████████████████████████████████████████████████████| 9828/9828 [00:00<00:00, 398792.83it/s]
100%|████████████████████████████████████████████████████████████████████████| 16872/16872 [00:00<00:00, 473137.95it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████| 2813/2813 [00:00<?, ?it/s]
100%|████████████████████████████████████████████████████████████████████████| 11985/11985 [00:00<00:00, 346743.09it/s]
100%|█████████████████████████████████████████████████████████████████████████| 5458/5458 [00:00<00:00, 2053876.84it/s]
100%|███████████████████████████████████

In [32]:
ad = [
    *ad_train_1, *ad_train_2, *ad_train_3, *ad_train_4, *ad_train_5,
    *ad_dev_1, *ad_dev_2, *ad_dev_3, *ad_dev_4, *ad_dev_5,
]

In [33]:
# Reuse the concatenation from the previous cell; the heading was copy-pasted as 'book'.
print_statistics(ad, 'ad')

book


100%|█████████████████████████████████████████████████████████████████████| 252689/252689 [00:00<00:00, 3910526.97it/s]


Counter({1: 234239, 0: 18450})

# 뉴스 기사 기계독해

In [34]:
# Raw news-article MRC splits, read via Path.read_text so handles are closed.
# (The original paths for *_4 contained a stray 'data//MRC_data' double slash.)
news_root = Path('../../data/MRC_data/뉴스 기사 기계독해/01.데이터')
news_train_dir = news_root / '1.Training'
news_dev_dir = news_root / '2.Validation/라벨링데이터'

news_mrc_train_1 = json.loads((news_train_dir / '라벨링데이터/TL_text_entailment/TL_text_entailment.json').read_text(encoding='utf-8'))
news_mrc_train_2 = json.loads((news_train_dir / '라벨링데이터/TL_span_extraction/TL_span_extraction.json').read_text(encoding='utf-8'))
# span_inference training labels live under the separate '라벨링데이터_221115_add' folder.
news_mrc_train_3 = json.loads((news_train_dir / '라벨링데이터_221115_add/TL_span_inference/TL_span_inference.json').read_text(encoding='utf-8'))
news_mrc_train_4 = json.loads((news_train_dir / '라벨링데이터/TL_unanswerable/TL_unanswerable.json').read_text(encoding='utf-8'))  # no answer
news_mrc_dev_1 = json.loads((news_dev_dir / 'VL_span_extraction/VL_span_extraction.json').read_text(encoding='utf-8'))
news_mrc_dev_2 = json.loads((news_dev_dir / 'VL_span_inference/VL_span_inference.json').read_text(encoding='utf-8'))
news_mrc_dev_3 = json.loads((news_dev_dir / 'VL_text_entailment/VL_text_entailment.json').read_text(encoding='utf-8'))
news_mrc_dev_4 = json.loads((news_dev_dir / 'VL_unanswerable/VL_unanswerable.json').read_text(encoding='utf-8'))  # no answer

In [35]:
# Parse every raw split in place (rebinds the raw-JSON names; re-running this cell alone fails).
# The unanswerable splits (*_4) need the no-answer parser.
(
    news_mrc_train_1, news_mrc_train_2, news_mrc_train_3, news_mrc_train_4,
    news_mrc_dev_1, news_mrc_dev_2, news_mrc_dev_3, news_mrc_dev_4,
) = [
    parse(raw['data'])
    for parse, raw in (
        (make_instance, news_mrc_train_1),
        (make_instance, news_mrc_train_2),
        (make_instance, news_mrc_train_3),
        (no_answer_make_instance, news_mrc_train_4),
        (make_instance, news_mrc_dev_1),
        (make_instance, news_mrc_dev_2),
        (make_instance, news_mrc_dev_3),
        (no_answer_make_instance, news_mrc_dev_4),
    )
]

100%|████████████████████████████████████████████████████████████████████████| 24009/24009 [00:00<00:00, 327413.86it/s]
100%|██████████████████████████████████████████████████████████████████████| 111967/111967 [00:00<00:00, 387147.00it/s]
100%|████████████████████████████████████████████████████████████████████████| 15992/15992 [00:00<00:00, 363177.81it/s]
100%|██████████████████████████████████████████████████████████████████████████| 8000/8000 [00:00<00:00, 344897.95it/s]
100%|████████████████████████████████████████████████████████████████████████| 13997/13997 [00:00<00:00, 300086.25it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████| 1999/1999 [00:00<?, ?it/s]
100%|██████████████████████████████████████████████████████████████████████████| 3001/3001 [00:00<00:00, 186084.18it/s]
100%|██████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 500036.24it/s]


In [36]:
news = [
    *news_mrc_train_1, *news_mrc_train_2, *news_mrc_train_3, *news_mrc_train_4,
    *news_mrc_dev_1, *news_mrc_dev_2, *news_mrc_dev_3, *news_mrc_dev_4,
]

In [37]:
# Reuse the concatenation from the previous cell; the heading was copy-pasted as 'book'.
print_statistics(news, 'news')

book


100%|█████████████████████████████████████████████████████████████████████| 359934/359934 [00:00<00:00, 3662502.68it/s]


Counter({1: 341934, 0: 18000})

# KLUE MRC

In [38]:
def klue_make_instance(data:List):
    """Convert KLUE-MRC examples (one question per example) into flat records.

    Args:
        data: iterable of example dicts with keys ``title``, ``context``,
            ``question``, ``answers`` (a dict holding ``text``) and
            ``is_impossible``.

    Returns:
        List of dicts with keys ``title``, ``context``, ``question``,
        ``answer`` (``answers['text']``) and ``label``. ``label`` is 1 when
        ``is_impossible`` equals False, otherwise 0.
    """
    def _to_record(example):
        return dict(
            title=example['title'],
            context=example['context'],
            question=example['question'],
            answer=example['answers']['text'],
            label=1 if example['is_impossible'] == False else 0,
        )

    return [_to_record(example) for example in tqdm(data)]

In [39]:
from datasets import load_dataset
# Load the KLUE MRC subset (used below via its 'train' and 'validation' splits).
dataset = load_dataset('klue', 'mrc')

Reusing dataset klue (C:\Users\User\.cache\huggingface\datasets\klue\mrc\1.0.0\e0fc3bc3de3eb03be2c92d72fd04a60ecc71903f821619cb28ca0e1e29e4233e)


  0%|          | 0/2 [00:00<?, ?it/s]

In [40]:
klue_mrc_train, klue_mrc_dev = [klue_make_instance(dataset[split]) for split in ('train', 'validation')]

100%|██████████████████████████████████████████████████████████████████████████| 17554/17554 [00:01<00:00, 9914.22it/s]
100%|████████████████████████████████████████████████████████████████████████████| 5841/5841 [00:00<00:00, 9431.31it/s]


In [41]:
# The heading was copy-pasted as 'book'; this is the KLUE MRC distribution.
print_statistics(klue_mrc_train + klue_mrc_dev, 'klue')

book


100%|███████████████████████████████████████████████████████████████████████| 23395/23395 [00:00<00:00, 4237777.68it/s]


Counter({1: 16045, 0: 7350})

In [42]:
klue = [*klue_mrc_train, *klue_mrc_dev]  # train + validation records

# total

In [43]:
# KorQuAD is not merged here; its statistics above show only label 1 (no unanswerable
# questions) — presumably the reason it is left out. Confirm this is intended.
total = [*klue, *mrc, *ad, *news, *book]

In [44]:
import hashlib
def compute_hash(text):
    """Return the hex MD5 digest of ``text`` encoded as UTF-8.

    Used as a document dedup key, not for any security purpose.
    """
    digest = hashlib.md5(text.encode('utf-8'))
    return digest.hexdigest()

In [45]:
# One entry per unique document, keyed by the MD5 of "title context".
# Questions are attached in the next cell.
total_docs = {}
for record in tqdm(total):
    # `doc_key` rather than `id`: a global named `id` would shadow the builtin for the session.
    doc_key = compute_hash(record['title'] + ' ' + record['context'])
    total_docs[doc_key] = dict(title=record['title'], context=record['context'], pos=[], neg=[])

100%|████████████████████████████████████████████████████████████████████| 2026350/2026350 [00:13<00:00, 150420.14it/s]


In [46]:
# Attach every question to its document: label 1 -> 'pos', label 0 -> 'neg'.
# Records with any other label are ignored.
for record in tqdm(total):
    # `doc_key` rather than `id`, which would shadow the builtin.
    doc_key = compute_hash(record['title'] + ' ' + record['context'])
    if record['label'] == 1:
        total_docs[doc_key]['pos'].append(record['question'])
    elif record['label'] == 0:
        total_docs[doc_key]['neg'].append(record['question'])

100%|████████████████████████████████████████████████████████████████████| 2026350/2026350 [00:10<00:00, 185529.50it/s]


In [47]:
# Number of documents with at least two label-1 and at least two label-0 questions.
chk = sum(1 for doc in total_docs.values() if len(doc['pos']) > 1 and len(doc['neg']) > 1)

In [48]:
# Keep only documents with at least two label-1 and at least two label-0 questions.
filtering_data = [
    doc
    for doc in total_docs.values()
    if len(doc['pos']) > 1 and len(doc['neg']) > 1
]

# Split (train / dev / test) and save

In [49]:
# Number of documents kept for the question-filtering dataset.
len(filtering_data)

56072

In [50]:
# Split at document level, so all questions about one context land in the same split.
# 2,000 documents go to dev, 2,000 to test, and the rest to train.
train, dev = train_test_split(
    filtering_data,
    test_size=2000,
    shuffle=True,
    random_state=42,
)
train, test = train_test_split(
    train,
    test_size=2000,
    shuffle=True,
    random_state=42,
)

In [51]:
def make_dataset(data):
    """Expand per-document entries into one row per question.

    For each document, all ``pos`` questions come first (``label=1``),
    followed by all ``neg`` questions (``label=0``). The totals for each
    label are printed.

    Args:
        data: iterable of dicts with keys ``title``, ``context``, ``pos``
            and ``neg`` (the latter two are lists of question strings).

    Returns:
        List of dicts with keys ``title``, ``context``, ``question``, ``label``.
    """
    rows = []
    counts = {1: 0, 0: 0}
    for doc in tqdm(data):
        for key, label in (('pos', 1), ('neg', 0)):
            for question in doc[key]:
                counts[label] += 1
                rows.append(dict(title=doc['title'], context=doc['context'], question=question, label=label))
    print(f'positive - {counts[1]}')
    print(f'neg - {counts[0]}')
    return rows

In [52]:
train_dataset = make_dataset(data=train)  # one row per (document, question) pair

100%|████████████████████████████████████████████████████████████████████████| 52072/52072 [00:00<00:00, 221556.78it/s]

positive - 268186
neg - 148341





In [53]:
# Evaluated left to right: dev first, then test.
dev_dataset, test_dataset = make_dataset(data=dev), make_dataset(data=test)

100%|██████████████████████████████████████████████████████████████████████████| 2000/2000 [00:00<00:00, 691786.90it/s]


positive - 10324
neg - 5697


100%|███████████████████████████████████████████████████████████████████████████| 2000/2000 [00:00<00:00, 91671.76it/s]

positive - 10304
neg - 5674





In [54]:
# Output directory for the question-filtering splits (absolute local path; adjust per machine).
# save_jsonl comes from the utils.utils star import.
SAVE_DIR = 'D:/jupyter_notebook/data/q_filtering_data'
for split_name, split_rows in (('train', train_dataset), ('dev', dev_dataset), ('test', test_dataset)):
    save_jsonl(SAVE_DIR, split_rows, split_name)

100%|██████████████████████████████████████████████████████████████████████| 416527/416527 [00:03<00:00, 104860.96it/s]
100%|████████████████████████████████████████████████████████████████████████| 16021/16021 [00:00<00:00, 112161.65it/s]
100%|████████████████████████████████████████████████████████████████████████| 15978/15978 [00:00<00:00, 101118.19it/s]
