In [1]:
import tensorflow as tf
print(tf.__version__)

import bisect
import gzip
import itertools
import json
import os
import pickle
import re
import shutil

import numpy as np
import pandas as pd
import tensorflow_hub as hub
from nltk.tokenize import sent_tokenize
from tqdm import tqdm

# Reduce logging output: only ERROR-level TensorFlow log messages are shown.
tf.logging.set_verbosity(tf.logging.ERROR)

1.14.0


In [2]:
def open_file(filename, encoding="utf-8"):
    """Open a text file for reading, transparently decompressing ``.gz`` files.

    Parameters
    ----------
    filename : str or path-like
        Path to the file. Paths ending in ``.gz`` are opened with ``gzip``.
    encoding : str, optional
        Text encoding used to decode the file (default ``"utf-8"``). Passing it
        explicitly avoids depending on the platform's locale encoding.

    Returns
    -------
    A text-mode file object. The caller is responsible for closing it,
    preferably with a ``with`` statement.
    """
    if str(filename).endswith(".gz"):
        return gzip.open(filename, "rt", encoding=encoding)
    return open(filename, "rt", encoding=encoding)

# Placeholder characters that hide sentence-ending punctuation inside answer
# spans from the sentence tokenizer. Each is a single character, so character
# offsets into the passage are preserved.
_ANSWER_PUNCT = ".!?"
_ANSWER_PUNCT_PLACEHOLDERS = "♬♪♩"
_ANSWER_PUNCT_UNMASK = str.maketrans(_ANSWER_PUNCT_PLACEHOLDERS, _ANSWER_PUNCT)

# Passage markup tags replaced by a space to improve sentence embedding quality.
_TAG_PATTERN = re.compile(r"\[(?:TLE|DOC|PAR)\]")


def _mask_answer_punctuation(context, answer_text, char_spans):
    """Return ``context`` with the answer's '.', '!' and '?' masked at every answer occurrence.

    Only punctuation characters that appear in ``answer_text`` are masked. Each
    occurrence covers ``len(answer_text)`` characters from its start offset.
    """
    table = {
        ord(ch): placeholder
        for ch, placeholder in zip(_ANSWER_PUNCT, _ANSWER_PUNCT_PLACEHOLDERS)
        if ch in answer_text
    }
    if not table:
        return context
    width = len(answer_text)
    for span in char_spans:
        start = span[0]
        end = start + width
        context = context[:start] + context[start:end].translate(table) + context[end:]
    return context


def _sentence_lengths(context, sentences):
    """Return how many characters of ``context`` each tokenized sentence occupies.

    ``sent_tokenize`` strips the whitespace between sentences, so each entry is
    the sentence length plus the whitespace that follows it. Leading passage
    whitespace, if the tokenizer dropped it, is credited to the first sentence.
    Cumulative sums of the result are the end offsets of the sentences.
    """
    n = len(context)
    lengths = []
    consumed = 0
    for sentence in sentences:
        start = consumed
        if not context.startswith(sentence, start):
            while start < n and context[start].isspace():
                start += 1
        end = start + len(sentence)
        # Absorb the separating whitespace. Use isspace() rather than only
        # '\n', ' ' and NBSP, so tabs and '\r' do not desync the offsets.
        while end < n and context[end].isspace():
            end += 1
        lengths.append(end - consumed)
        consumed = end
    return lengths


def _clean_sentence(text):
    """Replace passage tags with spaces and restore masked answer punctuation."""
    return _TAG_PATTERN.sub(" ", text).translate(_ANSWER_PUNCT_UNMASK)


def _locate_answer_sentences(char_spans, sentence_ends):
    """Pair answer spans with the index of the sentence containing their start offset.

    Only the first span found in each sentence is kept. A start offset at or
    beyond the last sentence end is attributed to the last sentence.
    """
    spans = []
    seen = set()
    last = len(sentence_ends) - 1
    for span in char_spans:
        # First sentence whose end offset is strictly greater than the start.
        idx = min(bisect.bisect_right(sentence_ends, span[0]), last)
        if idx not in seen:
            seen.add(idx)
            spans.append((span, idx))
    return spans


def collect_target_samples(filename):
    """Collect questions whose (first) answer occurs in more than one sentence.

    Each passage is split into sentences with ``sent_tokenize``. Before
    splitting, '.', '!' and '?' inside every answer occurrence are masked so
    the answer text is not split. Each answer occurrence is then mapped to the
    sentence holding its start offset.

    Parameters
    ----------
    filename : str
        JSON Lines file (optionally gzipped). The first line is skipped; it is
        presumably the dataset header (the rewrite cell copies it verbatim).

    Returns
    -------
    list of dict
        One dict per passage that has at least one qualifying question, with keys:

        - ``'sentence_text'``: the cleaned sentences.
        - ``'sentence_pos'``: the characters each sentence occupies; their
          cumulative sums are end offsets.
        - ``'question'``: a list of dicts with ``qid``, ``question``, ``answer``
          and ``ans_spans``. ``ans_spans`` is a list of
          ``(char_span, sentence_index)`` pairs.
    """
    with open_file(filename) as f:
        num_lines = sum(1 for _ in f)

    target_samples = []
    with open_file(filename) as data:
        for i, line in tqdm(enumerate(data), total=num_lines):
            if i == 0:
                continue
            jsondata = json.loads(line)

            # Keep '.', '!' and '?' inside answers from splitting sentences.
            context = jsondata['context']
            for q in jsondata['qas']:
                answer = q['detected_answers'][0]
                context = _mask_answer_punctuation(context, answer['text'], answer['char_spans'])

            sen_text = sent_tokenize(context)
            if not sen_text:
                continue
            sen_pos = _sentence_lengths(context, sen_text)
            sentence_ends = list(itertools.accumulate(sen_pos))
            sen_text = [_clean_sentence(s) for s in sen_text]

            questions = []
            for q in jsondata['qas']:
                # Only the first detected answer is used.
                answer = q['detected_answers'][0]
                spans = _locate_answer_sentences(answer['char_spans'], sentence_ends)
                if len(spans) > 1:
                    questions.append({
                        'qid': q['qid'],
                        'question': q['question'],
                        'answer': answer['text'],
                        'ans_spans': spans,
                    })

            if questions:
                target_samples.append({
                    'sentence_text': sen_text,
                    'sentence_pos': sen_pos,
                    'question': questions,
                })

    return target_samples

In [3]:
import pickle

# Set to True to rebuild the cache from the raw MRQA files even if it exists.
REBUILD_TARGET_SAMPLES = False
TRAIN_DIR = "./MRQA-Shared-Task-2019/download_train"
# Other splits can be processed the same way by pointing TRAIN_DIR at
# ./MRQA-Shared-Task-2019/download_in_domain_dev or download_out_of_domain_dev.
TARGET_SAMPLES_CACHE = "all_target_samples_in_domain_train.pickle"

if REBUILD_TARGET_SAMPLES or not os.path.exists(TARGET_SAMPLES_CACHE):
    # Only take the original gzipped datasets. Later cells write
    # "<file>_spans.pickle" and "<name>_revised.jsonl(.gz)" into the same
    # directory, and those files must not be re-ingested. Sort for a
    # deterministic order.
    files = sorted(
        os.path.join(TRAIN_DIR, name)
        for name in os.listdir(TRAIN_DIR)
        if name.endswith(".jsonl.gz") and not name.endswith("_revised.jsonl.gz")
    )
    print(files)

    all_target_samples = {}
    for file in files:
        print(file)
        target_samples = collect_target_samples(file)
        num_questions = sum(len(sample['question']) for sample in target_samples)
        print("Num. collected samples :", num_questions)

        if num_questions > 0:
            all_target_samples[file] = target_samples

    with open(TARGET_SAMPLES_CACHE, 'wb') as handle:
        pickle.dump(all_target_samples, handle)

else:
    # NOTE: pickle.load can execute arbitrary code; only load caches produced by this notebook.
    with open(TARGET_SAMPLES_CACHE, 'rb') as handle:
        all_target_samples = pickle.load(handle)

print(all_target_samples.keys())

dict_keys(['./MRQA-Shared-Task-2019/download_train/HotpotQA.jsonl.gz', './MRQA-Shared-Task-2019/download_train/NewsQA.jsonl.gz', './MRQA-Shared-Task-2019/download_train/SQuAD.jsonl.gz', './MRQA-Shared-Task-2019/download_train/TriviaQA.jsonl.gz', './MRQA-Shared-Task-2019/download_train/SearchQA.jsonl.gz'])


In [5]:
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer

def _blended_similarity(text_embed, text_vect, local, n_spans):
    """Average the USE and TF-IDF cosine similarity of one question to its span sentences.

    Row ``local`` of the batch holds the question. Its ``n_spans`` span
    sentences occupy the rows right after it. Returns a 1-D array of length
    ``n_spans``.
    """
    q_rows = slice(local, local + 1)
    s_rows = slice(local + 1, local + 1 + n_spans)
    embed_sim = cosine_similarity(text_embed[q_rows], text_embed[s_rows])
    vect_sim = cosine_similarity(text_vect[q_rows], text_vect[s_rows])
    return ((embed_sim + vect_sim) / 2)[0]


def correct_target_samples(data, batch_size=1000):
    """Pick, for every multi-span question, the answer sentence most similar to the question.

    The question is compared with each sentence that contains one of its answer
    spans. The score is the plain average of Universal Sentence Encoder cosine
    similarity and TF-IDF cosine similarity. When the best sentence is not the
    one holding the first span, the question is reported as a revision.

    Parameters
    ----------
    data : list of dict
        Samples as returned by ``collect_target_samples`` (keys
        ``'sentence_text'`` and ``'question'``).
    batch_size : int, optional
        Number of questions embedded per ``session.run`` call (default 1000).

    Returns
    -------
    list of dict
        One entry per revised question, with keys ``qid``, ``question``,
        ``answer``, ``origin_span`` and ``revise_span``. Each span value is a
        ``(sentence_text, (char_span, sentence_index))`` tuple.
        Empty input yields an empty list.
    """
    # Flatten to one text list: for each question, the question text followed
    # by the sentences of its answer spans. text_offset[q] is the question's index.
    qid, ans, spans, span_num, text_offset = [], [], [], [], []
    texts = []
    for d in tqdm(data):
        sentence_text = d['sentence_text']
        for q in d['question']:
            text_offset.append(len(texts))
            texts.append(q['question'])
            # Extend in place; `texts = texts + ...` copied the whole list per question.
            texts.extend(sentence_text[t[1]] for t in q['ans_spans'])
            spans.append(q['ans_spans'])
            qid.append(q['qid'])
            ans.append(q['answer'])
            span_num.append(len(q['ans_spans']))

    num_questions = len(qid)
    print("Num. questions :", num_questions)
    print("Num. texts     :", len(texts))

    if num_questions == 0:
        # TfidfVectorizer cannot be fit on an empty corpus.
        return []

    vectorizer = TfidfVectorizer().fit(texts)

    # Import the Universal Sentence Encoder's TF Hub module
    module_url = "https://tfhub.dev/google/universal-sentence-encoder/2" #@param ["https://tfhub.dev/google/universal-sentence-encoder/2", "https://tfhub.dev/google/universal-sentence-encoder-large/3"]

    correct_spans = []
    # Use a fresh graph so repeated calls (one per dataset) do not keep adding
    # Hub modules to the global default graph.
    with tf.Graph().as_default():
        embed = hub.Module(module_url)
        # Build the embedding op once and feed batches through a placeholder.
        # Calling embed(...) inside the loop would add new ops on every batch.
        text_input = tf.placeholder(tf.string, shape=[None])
        embed_op = embed(text_input)

        with tf.Session() as session:
            session.run([tf.global_variables_initializer(), tf.tables_initializer()])

            for start in tqdm(range(0, num_questions, batch_size)):
                stop = min(start + batch_size, num_questions)
                first_text = text_offset[start]
                end_text = text_offset[stop] if stop < num_questions else len(texts)
                batch_text = texts[first_text:end_text]

                text_embed = session.run(embed_op, feed_dict={text_input: batch_text})
                text_vect = vectorizer.transform(batch_text)

                for q in range(start, stop):
                    sim = _blended_similarity(text_embed, text_vect, text_offset[q] - first_text, span_num[q])
                    # argmax picks the lowest index on ties, so an exact tie
                    # with the first span does not trigger a revision.
                    most_sim_idx = int(np.argmax(sim))
                    if most_sim_idx != 0:
                        text_idx = text_offset[q]
                        correct_spans.append({
                            "qid": qid[q],
                            "question": texts[text_idx],
                            "answer": ans[q],
                            "origin_span": (texts[text_idx + 1], spans[q][0]),
                            "revise_span": (texts[text_idx + 1 + most_sim_idx], spans[q][most_sim_idx]),
                        })

    return correct_spans

# Revision candidates per dataset path. They are also saved to
# "<file>_spans.pickle" whenever there is at least one.
correct_result = {}
for file, target_samples in all_target_samples.items():
    print(file)
    correct_spans = correct_target_samples(target_samples)
    print("Num. revised spans :", len(correct_spans))
    correct_result[file] = correct_spans
    if correct_spans:
        with open(file + "_spans.pickle", 'wb') as handle:
            pickle.dump(correct_spans, handle)

  2%|▏         | 2094/101413 [00:00<00:04, 20937.02it/s]

./MRQA-Shared-Task-2019/download_train/SearchQA.jsonl.gz


100%|██████████| 101413/101413 [20:44<00:00, 81.47it/s] 


Num. quesitons : 101413
Num. texts     : 854166


100%|██████████| 102/102 [58:36<00:00, 52.75s/it]


Num. revised spans : 80665


In [6]:
for file in all_target_samples:
    print(file)

    spans_path = file + "_spans.pickle"
    if not os.path.exists(spans_path):
        # No revision candidates were saved for this dataset.
        continue
    # NOTE: only load pickles produced by this notebook (pickle can execute code).
    with open(spans_path, 'rb') as handle:
        correct_spans = pickle.load(handle)

    # qid -> revised char span ([start, end]); the last entry wins on duplicate qids.
    revised_char_span = {row['qid']: row['revise_span'][1][0] for row in correct_spans}

    path = os.path.dirname(file)
    filename = os.path.basename(file).split('.')[0]
    revised_path = os.path.join(path, filename + "_revised.jsonl")

    with open_file(file) as data:
        num_lines = sum(1 for _ in data)

    with open_file(file) as data, open(revised_path, "wt") as new_file:
        for i, line in tqdm(enumerate(data), total=num_lines):
            jsondata = json.loads(line)

            # The first line is copied through unchanged.
            if i > 0:
                for q in jsondata['qas']:
                    char_span = revised_char_span.get(q['qid'])
                    if char_span is None:
                        continue

                    answer = q['detected_answers'][0]
                    char_spans = answer['char_spans']
                    # Index of the revised occurrence; falls back to the first span.
                    revised_idx = next(
                        (k for k in range(1, len(char_spans))
                         if char_spans[k][0] == char_span[0] and char_spans[k][1] == char_span[1]),
                        0,
                    )
                    # Keep only the token span of the selected occurrence.
                    # NOTE(review): char_spans is left unchanged, as before.
                    answer['token_spans'] = [answer['token_spans'][revised_idx]]

            new_file.write(json.dumps(jsondata) + '\n')

    with open(revised_path, 'rt') as f_in, gzip.open(revised_path + '.gz', 'wt') as f_out:
        shutil.copyfileobj(f_in, f_out)

./MRQA-Shared-Task-2019/download_train/HotpotQA.jsonl.gz


100%|██████████| 72929/72929 [00:15<00:00, 4800.23it/s]


./MRQA-Shared-Task-2019/download_train/NewsQA.jsonl.gz


100%|██████████| 11429/11429 [00:08<00:00, 1302.17it/s]


./MRQA-Shared-Task-2019/download_train/SQuAD.jsonl.gz


100%|██████████| 18886/18886 [00:04<00:00, 4357.72it/s]


./MRQA-Shared-Task-2019/download_train/TriviaQA.jsonl.gz


100%|██████████| 61689/61689 [01:01<00:00, 1010.49it/s]


./MRQA-Shared-Task-2019/download_train/SearchQA.jsonl.gz


100%|██████████| 117385/117385 [01:49<00:00, 1070.41it/s]


In [5]:
import pandas as pd

# Rebuild the revision candidates from the per-dataset pickles. This way the
# cell does not depend on a variable left over from another (possibly deleted) cell.
correct_result = {}
for file in all_target_samples:
    spans_path = file + "_spans.pickle"
    if os.path.exists(spans_path):
        with open(spans_path, 'rb') as handle:
            correct_result[file] = pickle.load(handle)

records = [
    {
        'dataset': file,
        'qid': s['qid'],
        'question': s['question'],
        'answer': s['answer'],
        'origin_span': s['origin_span'],
        'revise_span': s['revise_span'],
    }
    for file, spans in correct_result.items()
    for s in spans
]

# Passing columns up front also gives a valid (empty) frame when there are no records.
df = pd.DataFrame(records, columns=['dataset', 'qid', 'question', 'answer', 'origin_span', 'revise_span'])
# NOTE(review): the output name says "out_domain" but the inputs above are the
# in-domain train files; confirm the intended filename.
df.to_csv("revise_spans_with_tfidf_out_domain.csv", index = False)

In [13]:
def test(filename):
    """Check that every line of a (possibly gzipped) JSON Lines file parses as JSON.

    Parameters
    ----------
    filename : str
        Path to the file to check.

    Returns
    -------
    int
        Number of lines parsed.

    Raises
    ------
    json.JSONDecodeError
        On the first malformed line.
    """
    with open_file(filename) as f:
        num_lines = sum(1 for _ in f)

    num_records = 0
    with open_file(filename) as data:
        for line in tqdm(data, total=num_lines):
            json.loads(line)
            num_records += 1
    return num_records

In [14]:
# Re-read each revised gzip file to confirm every line is valid JSON.
for dataset_path in all_target_samples.keys() :
    directory, _sep, base_name = dataset_path.rpartition('/')
    stem = base_name.partition('.')[0]
    print(stem)
    test(os.path.join(directory, stem + "_revised.jsonl.gz"))

HotpotQA


100%|██████████| 72929/72929 [00:06<00:00, 10452.68it/s]


NewsQA


100%|██████████| 11429/11429 [00:05<00:00, 2005.08it/s]


SQuAD


100%|██████████| 18886/18886 [00:01<00:00, 9657.75it/s]


TriviaQA


100%|██████████| 61689/61689 [00:35<00:00, 1746.00it/s]


SearchQA


100%|██████████| 117385/117385 [01:06<00:00, 1761.26it/s]
