In [20]:
import json

In [21]:
# Parse the JSON-lines test file: each line becomes one entry of `minimized`
# (later cells read the 'sentences' field of every entry).
with open("data/test.english.jsonlines") as f:
    data_lines = f.readlines()
minimized = [json.loads(raw_line) for raw_line in data_lines]

In [22]:
# Parse the model output file: each line is one JSON value appended to `preds`
# (the next cell unpacks it into exactly two objects).
with open("output/preds.jsonl") as f:
    data_lines = f.readlines()
preds = [json.loads(raw_line) for raw_line in data_lines]

In [23]:
# `preds` must hold exactly two objects: the per-document predictions and the
# per-document subtoken maps, both keyed by document key.
doc_to_prediction, doc_to_subtoken_map = preds


def _doc_number(doc_key):
    """Return, as an int, the second '/'-separated field of the text before the first '_'."""
    prefix = doc_key.split('_')[0]
    return int(prefix.split('/')[1])


# Document keys in ascending numeric order.
keys = sorted(doc_to_prediction, key=_doc_number)

In [24]:
# Re-key both dicts as "faa_<number>_0", where <number> is the numeric part of
# the original document key.
# NOTE(review): the following cell rebinds `predictions` and `subtoken_map` to
# the original (un-rekeyed) dicts, so this re-keying currently has no effect
# downstream -- confirm which keying is intended.
predictions, subtoken_map = {}, {}
for key in keys:
    doc_number = key.split('_')[0].split('/')[1]
    idx_key = f"faa_{doc_number}_0"
    predictions[idx_key] = doc_to_prediction[key]
    subtoken_map[idx_key] = doc_to_subtoken_map[key]

In [25]:
# From here on, work with the dicts exactly as loaded (original document keys).
predictions, subtoken_map = doc_to_prediction, doc_to_subtoken_map

In [26]:
# Using code taken from s2e-coref/conll.py

In [27]:
import collections, operator, re

In [28]:
# Matches a "#begin document (<doc id>); part <n>" header line; group 1 is the
# document id and group 2 the part number (both are fed to get_doc_key).
BEGIN_DOCUMENT_REGEX = re.compile(r"#begin document \((.*)\); part (\d+)")  # First line at each document
# Captures, in order, the recall, precision and F1 percentages from text of the
# form "Coreference: Recall: (a / b) r%<TAB>Precision: (c / d) p%<TAB>F1: f%".
# re.DOTALL lets the leading/trailing ".*" span newlines.
# Not referenced by any of the cells visible in this notebook excerpt.
COREF_RESULTS_REGEX = re.compile(r".*Coreference: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tPrecision: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tF1: ([0-9.]+)%.*", re.DOTALL)

In [29]:
# For every document, index the predicted mentions by word position:
#   start_map[word] -> cluster ids of multi-word mentions opening at `word`
#   end_map[word]   -> cluster ids of multi-word mentions closing at `word`
#   word_map[word]  -> cluster ids of mentions that cover only `word`
# Mention boundaries are translated through subtoken_map[doc_key] first.
by_other_boundary = operator.itemgetter(1)

prediction_map = {}
for doc_key, clusters in predictions.items():
    start_map, end_map, word_map = (collections.defaultdict(list) for _ in range(3))

    for cluster_id, mentions in enumerate(clusters):
        for sub_start, sub_end in mentions:
            word_start = subtoken_map[doc_key][sub_start]
            word_end = subtoken_map[doc_key][sub_end]
            if word_start == word_end:
                word_map[word_start].append(cluster_id)
                continue
            start_map[word_start].append((cluster_id, word_end))
            end_map[word_end].append((cluster_id, word_start))

    # Reduce each (cluster_id, other_boundary) list to bare cluster ids, ordered
    # by the other boundary in descending order. Only values of existing keys
    # are replaced, so iterating .items() while assigning is safe.
    for boundary_map in (start_map, end_map):
        for word_idx, tagged_ids in boundary_map.items():
            ordered = sorted(tagged_ids, key=by_other_boundary, reverse=True)
            boundary_map[word_idx] = [tagged[0] for tagged in ordered]

    prediction_map[doc_key] = (start_map, end_map, word_map)

In [31]:
def get_doc_key(doc_id, part):
    """Build the "<doc_id>_<part>" document key, with `part` normalised through int()."""
    part_number = int(part)
    return f"{doc_id}_{part_number}"

def output_conll(input_file, output_file, prediction_map):
    """Copy a CoNLL-style file, replacing the last column with predicted coreference tags.

    Blank lines are written back as a single newline. Lines whose first field
    starts with "#" are written back unchanged followed by one extra newline;
    a "#begin document" header additionally selects that document's maps and
    resets the word counter. Every other line is treated as a token row: its
    last field becomes "-" or a "|"-joined list of tags, and the fields are
    re-joined with three spaces.

    Args:
        input_file: text file object to read (only ``readlines()`` is used).
        output_file: text file object to write to (only ``write()`` is used).
        prediction_map: maps a document key (see ``get_doc_key``) to a
            ``(start_map, end_map, word_map)`` triple; each map is keyed by
            word index and lists the cluster ids that open, close, or fully
            cover a mention at that word.

    Raises:
        KeyError: a "#begin document" header names a document that is missing
            from ``prediction_map``.
        AssertionError: a token row's first two fields disagree with the most
            recent "#begin document" header.
    """
    word_index = 0
    for line in input_file.readlines():
        row = line.split()

        if not row:
            output_file.write("\n")
            continue

        if row[0].startswith("#"):
            begin_match = re.match(BEGIN_DOCUMENT_REGEX, line)
            if begin_match:
                doc_key = get_doc_key(begin_match.group(1), begin_match.group(2))
                start_map, end_map, word_map = prediction_map[doc_key]
                word_index = 0
            output_file.write(line)
            output_file.write("\n")
            continue

        # Token row: it must belong to the document announced by the last header.
        assert get_doc_key(row[0], row[1]) == doc_key

        # Tag order is fixed: closing brackets, single-word mentions, openings.
        coref_list = ["{})".format(cid) for cid in end_map.get(word_index, ())]
        coref_list += ["({})".format(cid) for cid in word_map.get(word_index, ())]
        coref_list += ["({}".format(cid) for cid in start_map.get(word_index, ())]

        row[-1] = "|".join(coref_list) if coref_list else "-"

        output_file.write("   ".join(row))
        output_file.write("\n")
        word_index += 1

In [32]:
# Write the predictions next to the gold CoNLL file's other columns.
# Both files are opened in a `with` block so they are closed (and the output
# flushed to disk) as soon as the conversion finishes, even if it raises.
with open('data/test.english.v4_gold_conll') as input_file, \
        open('output/preds.conll', 'w') as output_file:
    output_conll(input_file, output_file, prediction_map)

In [33]:
import pandas as pd

# Load the FAA maintenance-text CSV; the 'c5' and 'c119' columns are read by
# the export cell below (c5 is compared with the id embedded in each document key).
faa_df = pd.read_csv('../../data/FAA_data/Maintenance_Text_data_nona.csv')

In [41]:
# Build one output row per FAA record: the record's id and text, the three
# prediction maps, and every predicted cluster's mention spans both as
# [start, end] word indices and as readable text.
outdict = {'c5_id':[], 'c119_input':[], 'start_map':[], 'end_map':[], 'word_map':[], 'corefs_human_readable':[],'corefs':[]}

for idoc in range(len(faa_df)):

    # Row idoc of faa_df is assumed to correspond to keys[idoc] and minimized[idoc].
    key = keys[idoc]

    # Sanity check: the id embedded in the document key should equal the row's c5.
    if key.split('/')[1].split('_')[1] != faa_df['c5'].iat[idoc]:
        print("we have a problem")

    doc_start_map, doc_end_map, doc_word_map = prediction_map[key]

    outdict['c5_id'].append(faa_df['c5'].iat[idoc])
    outdict['c119_input'].append(faa_df['c119'].iat[idoc])
    outdict['start_map'].append(dict(doc_start_map))
    outdict['end_map'].append(dict(doc_end_map))
    outdict['word_map'].append(dict(doc_word_map))

    # Flatten the document's sentences into a single token list (linear time,
    # instead of repeatedly concatenating lists).
    part = [token for sentence in minimized[idoc]['sentences'] for token in sentence]

    # Recover each cluster's spans directly from the model output.
    # start_map / end_map only record cluster ids per boundary word, so
    # re-pairing "the i-th start with the i-th end" from them attaches starts
    # to the wrong ends whenever two clusters share a boundary word.
    multi_word_spans = {}
    single_word_spans = {}
    for cluster_id, mentions in enumerate(predictions[key]):
        for sub_start, sub_end in mentions:
            start, end = subtoken_map[key][sub_start], subtoken_map[key][sub_end]
            bucket = single_word_spans if start == end else multi_word_spans
            bucket.setdefault(cluster_id, []).append([start, end])

    # Emit clusters in order of first appearance in start_map, then word_map,
    # and list multi-word spans before single-word ones. This is the ordering
    # this cell has always produced, so exported files stay comparable.
    cluster_order = list(dict.fromkeys(
        cluster_id
        for boundary_map in (doc_start_map, doc_word_map)
        for cluster_ids in boundary_map.values()
        for cluster_id in cluster_ids
    ))

    human_corefs = {}
    corefs = {}
    for cluster_id in cluster_order:
        spans = multi_word_spans.get(cluster_id, []) + single_word_spans.get(cluster_id, [])
        corefs[cluster_id] = spans
        # `end` is inclusive, hence the +1 when slicing.
        human_corefs[cluster_id] = [' '.join(part[start:end+1]) for start, end in spans]

    outdict['corefs_human_readable'].append(human_corefs)
    outdict['corefs'].append(list(corefs.values()))

In [42]:
# Export the assembled rows to CSV without the positional index column.
results_df = pd.DataFrame(outdict)
results_df.to_csv('../../data/results/s2e-coref/s2e-coref_updated_format.csv', index=False)

In [60]:
# Spot-check a single record by its c5 id; the last expression is displayed.
out_df = pd.DataFrame(outdict)
is_target_record = out_df['c5_id'] == '19950826026019A'
out_df.loc[is_target_record]

Unnamed: 0,c5_id,c119_input,start_map,end_map,word_map,corefs_human_readable,corefs
1920,19950826026019A,"EXPLOSION LIFTING LOGS. PITCHED UP, ROLLED. CO...",{},{},{},{},[]
