In [1]:
import pandas as pd
import os
from collections import defaultdict
import pickle
import json
import itertools
# Display every row/column and untruncated cell contents for DataFrames.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
# `None` means "no truncation". The old `-1` sentinel was deprecated in
# pandas 1.0 and is rejected by pandas >= 2.0.
pd.set_option('display.max_colwidth', None)

  # This is added back by InteractiveShellApp.init_path()


## Get your data files (for this example we used a fusion dataset)
    Use this notebook as an example of how to prepare your data for running the QA-Align model. This notebook is not runnable as-is.

In [7]:
# Load the train / dev / test splits of the extended fusion dataset.
train, dev, test = (
    pd.read_csv(split_path)
    for split_path in (
        "../../../MultiDocumentFusion/data/train/extended_doc-pyr_train.csv",
        "../../../MultiDocumentFusion/data/train/extended_doc-pyr_dev.csv",
        "../../../MultiDocumentFusion/data/train/extended_doc-pyr_test.csv",
    )
)

In [107]:
# Show all rows whose abs_sent_id matches one specific summary sentence.
# NOTE(review): `train2` is not defined anywhere in this notebook (hidden
# kernel state). Presumably it is a processed variant of `train` — confirm.
train2[train2.abs_sent_id == 'TAC2008~!~D0834-B-CDEF~!~12']

Unnamed: 0,scu_label,abs_sent_id,text,scu_text_spans,scu_spans,start_char_idx,end_char_idx,prev_text,prev_text_abs_sent_id,scu_uid,abs_scu_id,year,topic,source,has_verb,scu_span_ratio,span_len,w_overlap,source_short_long,text_len,label_len,min_w_overlap,label_to_scu_ratio,one_contributor,one_short_or_long_source,no_source_span_ratio,prev_scu_label,document,weight,stripped_id
0,Trial was named Kitzmiller vs. Dover Area School District,TAC2008~!~D0834-B-CDEF~!~12,For three weeks plaintiffs in the Kitzmiller vs. Dover Area School District trial presented expert witnesses who testified that intelligent design does not meet the definition of science because it cannot be proved or disproved.,['Kitzmiller vs. Dover Area School District trial'],"[(34, 81)]",786.0,1014.0,-----------------,TAC2008~!~D0834-B-CDEF~!~1,6,TAC2008~!~D0834-B-CDEF~!~6~!~SCU,2008,D0834-B-CDEF,reference_summary_sentence,True,0.2,7.0,7,True,35,9,7,1.285714,True,False,False,,,,2008~!~D0834-B-CDEF~!~6~!~SCU
2,Plaintiffs argued that ID is not [testable] science,TAC2008~!~D0834-B-CDEF~!~12,For three weeks plaintiffs in the Kitzmiller vs. Dover Area School District trial presented expert witnesses who testified that intelligent design does not meet the definition of science because it cannot be proved or disproved.,"['plaintiffs', 'presented expert witnesses who testified that intelligent design does not meet the definition of science because it cannot be proved or disproved']","[(16, 26), (82, 227)]",786.0,1014.0,-----------------,TAC2008~!~D0834-B-CDEF~!~1,7,TAC2008~!~D0834-B-CDEF~!~7~!~SCU,2008,D0834-B-CDEF,reference_summary_sentence,True,0.657143,23.0,2,True,35,8,2,0.533333,False,False,False,,,,2008~!~D0834-B-CDEF~!~7~!~SCU


In [29]:
def get_df_sents(df):
    """Return the distinct `abs_sent_id` values of `df` and print their count.

    The values come back in first-appearance order, as returned by
    `Series.unique`.
    """
    unique_ids = df.abs_sent_id.unique()
    print("# unique sents: ", len(unique_ids))
    return unique_ids

In [28]:
import spacy
nlp = spacy.load('en_core_web_sm')
def tokenize_sentence(sentence):
    """Run the spaCy pipeline on `sentence` and return its token texts as a list."""
    return [token.text for token in nlp(sentence)]

In [27]:
def create_and_save_qasrl_input(data_sents, orig_data, tokenizer=None):
    """Build the per-sentence table the QASRL parser input is written from.

    Despite the name, nothing is written to disk here; callers save the result.

    Args:
        data_sents: iterable of sentence ids (values of `orig_data.abs_sent_id`).
        orig_data: DataFrame with `abs_sent_id` and `text` columns. When an id
            appears on several rows, the text of its first row is used.
        tokenizer: callable mapping a sentence string to a list of token
            strings. Defaults to `tokenize_sentence` (spaCy).

    Returns:
        DataFrame with columns `qasrl_id`, `sentence` and `tokens`, where
        `tokens` holds the space-joined tokens. Duplicate ids are dropped
        (the first occurrence is kept).

    Raises:
        KeyError: if some id in `data_sents` does not occur in `orig_data`.
    """
    if tokenizer is None:
        tokenizer = tokenize_sentence
    # Build the id -> text lookup once (first occurrence wins, matching a
    # per-id `.iloc[0]` lookup). This avoids re-filtering the whole frame for
    # every id, which is quadratic.
    id_to_text = (orig_data.drop_duplicates('abs_sent_id')
                  .set_index('abs_sent_id')['text'])
    sents = pd.DataFrame(data_sents, columns=['qasrl_id'])
    missing = ~sents['qasrl_id'].isin(id_to_text.index)
    if missing.any():
        raise KeyError("sentence ids not found in orig_data: "
                       f"{list(sents.loc[missing, 'qasrl_id'])}")
    sents['sentence'] = sents['qasrl_id'].map(id_to_text)
    sents['tokens'] = sents['sentence'].apply(lambda s: " ".join(tokenizer(s)))
    return sents.drop_duplicates(['qasrl_id'])

In [26]:
import json
def create_json_qasrl_input(df, output):
    """Write `df` to `output` as JSON lines for the QASRL parser.

    Each row becomes one object with the keys "qasrl_id", "sentence" and
    "tokens", followed by a newline. Other columns are ignored.
    """
    with open(output, "w") as out_file:
        for qasrl_id, sentence, tokens in zip(df['qasrl_id'], df['sentence'], df['tokens']):
            record = {"qasrl_id": qasrl_id, "sentence": sentence, "tokens": tokens}
            out_file.write(json.dumps(record) + '\n')

In [25]:
def remove_dups(data):
    """Find duplicate QA rows within each sentence.

    Within every `qasrl_id` group, a row counts as a duplicate when the first
    earlier row of that group with the same question has an answer made of
    exactly the same set of whitespace-separated words (word-set IoU of 1).

    Args:
        data: DataFrame with at least `qasrl_id`, `question` and `answer`
            columns (answers are expected to be strings).

    Returns:
        List of index labels of `data` to drop. The first occurrence of each
        question is kept.
    """
    inds_to_remove = []
    for _, group in data.groupby("qasrl_id"):
        # First answer seen for each question in this sentence; later rows
        # with that question are compared only against it.
        first_answer = {}
        for idx, row in group.iterrows():
            question = row['question']
            if question not in first_answer:
                first_answer[question] = row['answer']
            # IoU == 1 <=> identical word sets. Comparing the sets directly
            # also avoids ZeroDivisionError when both answers are empty.
            elif set(row['answer'].split()) == set(first_answer[question].split()):
                inds_to_remove.append(idx)
    return inds_to_remove

### Converting the sentence data into the format the QASRL parser accepts

In [30]:
# Distinct summary-sentence ids of each split (the counts are printed).
train_sents, dev_sents, test_sents = (get_df_sents(split_df) for split_df in (train, dev, test))

# unique sents:  9066
# unique sents:  1592
# unique sents:  4767


In [56]:
# For each split, build one row per distinct sentence id (id, raw text,
# space-joined spaCy tokens). Write it both as CSV and as JSON lines, one
# {"qasrl_id", "sentence", "tokens"} object per line, for the QASRL parser.
# Note: the dev split is saved under the "val" file prefix.
final_train_sents = create_and_save_qasrl_input(train_sents,train)
final_train_sents.to_csv("../../data/predict/extended_fusion/train_extended_fusion_sents.csv", index=False)
create_json_qasrl_input(final_train_sents, "../../data/predict/extended_fusion/train_extended_fusion_sents.json")

final_val_sents = create_and_save_qasrl_input(dev_sents,dev)
final_val_sents.to_csv("../../data/predict/extended_fusion/val_extended_fusion_sents.csv", index=False)
create_json_qasrl_input(final_val_sents, "../../data/predict/extended_fusion/val_extended_fusion_sents.json")

final_test_sents = create_and_save_qasrl_input(test_sents,test)
final_test_sents.to_csv("../../data/predict/extended_fusion/test_extended_fusion_sents.csv", index=False)
create_json_qasrl_input(final_test_sents, "../../data/predict/extended_fusion/test_extended_fusion_sents.json")


## At this point we run a QASRL parser on the sentence files above, using the released model: https://github.com/nafitzgerald/nrl-qasrl

### After QASRL prediction

In [109]:
# Load the QASRL parser predictions for each split: one JSON object per line.
with open("../../data/predict/extended_fusion/train_extended_fusion_sents_pred.json") as pred_file:
    train_sents = [json.loads(line) for line in pred_file]
with open("../../data/predict/extended_fusion/val_extended_fusion_sents_pred.json") as pred_file:
    val_sents = [json.loads(line) for line in pred_file]
with open("../../data/predict/extended_fusion/test_extended_fusion_sents_pred.json") as pred_file:
    test_sents = [json.loads(line) for line in pred_file]

In [110]:
# Flatten the parser predictions into one row per predicted QA pair.
# NOTE(review): `format_qasrl` is defined in a cell further down (In [4]);
# on a fresh kernel, that cell must be run first.
train_qas = format_qasrl(train_sents)
val_qas = format_qasrl(val_sents)
test_qas = format_qasrl(test_sents)
train_post_qasrl = pd.DataFrame(train_qas, columns=['qasrl_id','tokens','verb','question','answer','answer_range', 'verb_idx'])
val_post_qasrl = pd.DataFrame(val_qas, columns=['qasrl_id','tokens','verb','question','answer','answer_range', 'verb_idx'])
test_post_qasrl = pd.DataFrame(test_qas, columns=['qasrl_id','tokens','verb','question','answer','answer_range', 'verb_idx'])

Number of qas extracted:  55424
Number of qas extracted:  8669
Number of qas extracted:  30615


In [88]:
# Persist the flattened QA tables per split; the next cell reads them back.
train_post_qasrl.to_csv("../../data/predict/extended_fusion/train_extended_fusion_sents_qasrl_gen.csv",index=False)
val_post_qasrl.to_csv("../../data/predict/extended_fusion/val_extended_fusion_sents_qasrl_gen.csv",index=False)
test_post_qasrl.to_csv("../../data/predict/extended_fusion/test_extended_fusion_sents_qasrl_gen.csv",index=False)

In [93]:
# Read back the per-split QA tables written by the previous cell.
train_post_qasrl, val_post_qasrl, test_post_qasrl = (
    pd.read_csv(csv_path)
    for csv_path in (
        "../../data/predict/extended_fusion/train_extended_fusion_sents_qasrl_gen.csv",
        "../../data/predict/extended_fusion/val_extended_fusion_sents_qasrl_gen.csv",
        "../../data/predict/extended_fusion/test_extended_fusion_sents_qasrl_gen.csv",
    )
)

### Once we have the sentence files and the QASRL-predicted QA files, we run `prep_post_qasrl_pre_qaalign.py` (found under `crowdsourcing/QASRL/qasrl_processing/`) to combine the two.

### Then we create sentence pairs. (The way we create pairs here is specific to the fusion task; use whatever pairing method is relevant to your task.)

In [121]:
# Load the per-split sentence+QA tables. These are presumably the output of
# prep_post_qasrl_pre_qaalign.py (see above) — confirm.
train_wqas, val_wqas, test_wqas = (
    pd.read_csv(wqas_path)
    for wqas_path in (
        "../../data/predict/extended_fusion/train_extended_fusion_wqas.csv",
        "../../data/predict/extended_fusion/val_extended_fusion_wqas.csv",
        "../../data/predict/extended_fusion/test_extended_fusion_wqas.csv",
    )
)

In [128]:
# Spot-check a single (positional) row of the test split.
test_wqas.iloc[12]

scu_label                   Poachers have endangered wildlife                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           

In [129]:
# Build the same-SCU sentence pairs for each split. `create_pairs` is defined
# further down (In [4]) and prints the average cluster size; its second
# return value is kept per split.
# NOTE(review): the val/test messages also say "Num Train pairs".
train_pairs,scu2size = create_pairs(train_wqas)
print("Train Num Train pairs ",len(train_pairs))
print()
val_pairs,val_scu2size = create_pairs(val_wqas)
print("Val Num Train pairs ",len(val_pairs))
print()
test_pairs,test_scu2size = create_pairs(test_wqas)
print("Test Num Train pairs ",len(test_pairs))
print()

average size clusters  3.4101726551563227
Train Num Train pairs  26321

average size clusters  2.9023809523809523
Val Num Train pairs  2720

average size clusters  3.375841750841751
Test Num Train pairs  14282



In [130]:
# Add an order-independent key per pair: both sentence ids, sorted and
# joined by get_sorted (defined in a later cell).
train_pairs['key'] = train_pairs.apply(get_sorted, axis=1)
val_pairs['key'] = val_pairs.apply(get_sorted, axis=1)
test_pairs['key'] = test_pairs.apply(get_sorted, axis=1)

In [139]:
# Save each split's sentence pairs, including the `key` column, to CSV.
train_pairs.to_csv("../../data/predict/extended_fusion/train_extended_fusion_pairs_wqas.csv",index=False)
val_pairs.to_csv("../../data/predict/extended_fusion/val_extended_fusion_pairs_wqas.csv",index=False)
test_pairs.to_csv("../../data/predict/extended_fusion/test_extended_fusion_pairs_wqas.csv",index=False)

26321

In [66]:
def get_sorted(x):
    """Return an order-independent key for a sentence pair.

    The two ids `x['abs_sent_id_1']` and `x['abs_sent_id_2']` are put in
    ascending order and joined with "~!~".
    """
    first, second = x['abs_sent_id_1'], x['abs_sent_id_2']
    if second < first:
        first, second = second, first
    return "~!~".join((first, second))

In [4]:
def create_pairs(data):
    """Pair up the sentences that contribute to the same SCU.

    Rows are grouped by `abs_scu_id`. Every unordered pair of rows in a group
    whose `abs_sent_id` values differ becomes one output row.

    The columns read are abs_scu_id, abs_sent_id, text, tokens, year, topic,
    w_overlap and qas. The input frames were noted to look like:
        ['abs_scu_id', 'scu_label', 'abs_sent_id', 'text', 'text_spans', 'year',
         'filepath', 'multi_sent', 'topic', 'w_overlap', 'qas', 'tokens']

    Args:
        data: DataFrame with the columns listed above.

    Returns:
        A tuple `(pairs, scu2size)`:
        - `pairs` is a DataFrame with the columns in `columns` below. It is
          empty when no pair can be formed.
        - `scu2size` maps a cluster size to the list of `abs_scu_id` values of
          that size. Singleton clusters are recorded there but yield no pairs.

    Prints the average cluster size when `data` has at least one cluster.
    """
    columns = ['abs_scu_id', 'abs_sent_id_1', 'text_1', 'tokens_1',
               'abs_sent_id_2', 'text_2', 'tokens_2',
               'year_1', 'topic_1', 'w_overlap_1', 'qas_1',
               'year_2', 'topic_2', 'w_overlap_2', 'qas_2']
    scu2size = defaultdict(list)
    cluster_sizes = []
    sent2sent = []  # accumulated sentence pairs, one tuple per output row
    for scu_id, cluster in data.groupby("abs_scu_id"):
        scu2size[len(cluster)].append(scu_id)
        cluster_sizes.append(len(cluster))
        # A singleton cluster simply yields no combinations. Previously it
        # aborted the whole run and returned None, which broke the caller's
        # `pairs, sizes = create_pairs(...)` unpacking.
        rows = zip(cluster['abs_sent_id'], cluster['text'], cluster['tokens'],
                   cluster['year'], cluster['topic'], cluster['w_overlap'],
                   cluster['qas'])
        for first, second in itertools.combinations(rows, 2):
            id1, text1, tok1, year1, topic1, wo1, qas1 = first
            id2, text2, tok2, year2, topic2, wo2, qas2 = second
            if id1 == id2:  # the same sentence listed twice under this SCU
                continue
            sent2sent.append((scu_id, id1, text1, tok1, id2, text2, tok2,
                              year1, topic1, wo1, qas1,
                              year2, topic2, wo2, qas2))
    if cluster_sizes:
        print("average size clusters ", sum(cluster_sizes) / len(cluster_sizes))
    # Always return the (frame, sizes) pair so callers can unpack it, even
    # when no pairs were produced.
    return pd.DataFrame(sent2sent, columns=columns), scu2size


def format_qasrl(sents):
    '''Flatten predicted QASRL output into one tuple per predicted QA pair.

    Each tuple is (qasrl_id, words, verb, question, answer_text,
    "start:end", verb_index). Only the first answer span of each QA pair is
    used. Prints the total number of tuples produced.
    '''
    post_qas = []
    for sent in sents:
        sent_id = sent["qasrl_id"]
        words = sent['words']
        for verb_obj in sent["verbs"]:
            verb, verb_idx = verb_obj["verb"], verb_obj["index"]
            for qa_pair in verb_obj['qa_pairs']:
                first_span = qa_pair["spans"][0]
                answer_range = ":".join((str(first_span["start"]), str(first_span["end"])))
                post_qas.append((sent_id, words, verb, qa_pair["question"],
                                 first_span["text"], answer_range, verb_idx))
    print("Number of qas extracted: ", len(post_qas))
    return post_qas