### 0. Import libraries and read data

In [1]:
import copy
import json
import tqdm
from itertools import groupby
from operator import itemgetter
from flair.data import Sentence


In [2]:
def read_data(input_file_path):
    """Load a JSON Lines file into a list of Python objects.

    Args:
        input_file_path: Path of a text file holding one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Read as UTF-8 explicitly so the result does not depend on the platform's
    # default encoding; skip whitespace-only lines (e.g. a trailing newline)
    # because json.loads rejects them.
    with open(input_file_path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

In [3]:
def write_data(output_file_path, data):
    """Write ``data`` to ``output_file_path`` as JSON Lines.

    Each element of ``data`` is serialised with ``json.dumps`` and written on
    its own line. An existing file at that path is overwritten.
    """
    with open(output_file_path, "w") as sink:
        sink.writelines(json.dumps(record) + "\n" for record in data)

In [4]:
# Load the records to process (one JSON object per line). Later cells read the
# "context", "answers", "POS_context" and "NER_context" fields of each record.
data = read_data("../../data/squad_data_validation_pos_ner_18_single_agg.json")

In [5]:
# Bracketed named-entity markers that may appear as standalone tokens in the
# tagged context; used to recognise (and skip over) inserted tags.
NER_TAGS = [
    '[DATE]',
    '[QUANTITY]',
    '[PRODUCT]',
    '[ORDINAL]',
    '[WORK_OF_ART]',
    '[LANGUAGE]',
    '[FAC]',
    '[EVENT]',
    '[ORG]',
    '[LAW]',
    '[CARDINAL]',
    '[GPE]',
    '[PERCENT]',
    '[NORP]',
    '[PERSON]',
    '[MONEY]',
    '[LOC]',
    '[TIME]',
]

### 1. Find the token index where the answer starts

In [6]:
def is_subustring_in_string(string, substring, symbols):
    """Tell whether ``string`` holds a symbol and also holds ``substring``.

    Returns True only when at least one entry of ``symbols`` occurs in
    ``string`` and ``substring`` occurs in ``string`` either verbatim or after
    all spaces have been removed from both; otherwise returns False.
    """
    contains_symbol = any(symbol in string for symbol in symbols)
    if not contains_symbol:
        return False

    if substring in string:
        return True

    # Fall back to a comparison that ignores spaces on both sides.
    return substring.replace(" ", "") in string.replace(" ", "")

In [7]:
def is_one_text_in_context(text, context, symbols):
    """Tell whether exactly one symbol follows ``text`` somewhere in ``context``.

    For every entry of ``symbols`` the string ``text + symbol`` is looked up in
    ``context``. The result is True when precisely one of those lookups
    succeeds, regardless of how many times that combination occurs.
    """
    following_symbols = [symbol for symbol in symbols if text + symbol in context]
    return len(following_symbols) == 1

In [8]:
def find_start_position(text, pos_context, answer, context):
    """Find the index in ``pos_context`` of the token where the answer starts.

    A window as long as the tokenised answer slides over ``pos_context``; the
    first window that satisfies one of the match rules below wins.

    Args:
        text: Answer string; it is tokenised with flair's ``Sentence``.
        pos_context: Sequence of mappings; the first element of each mapping's
            first value is used as the token text.
        answer: Offset into ``context`` (used as ``context[:answer]``) at which
            the answer is expected to start.
        context: The raw context string.

    Returns:
        The window start index when a match rule fires, otherwise ``no_match``
        (initialised to -1).
    """
    tokens = []
    sentence = Sentence(text)
    symbols = [" ", ")", "(", "%", "-", ".", ","]

    # number of spaces in the context before the expected start offset
    spaces_no = context[:answer].count(" ")

    # get tokens of the answer
    for word in sentence:
        tokens.append(word.text)

    start_position = 0
    # running sum of the lengths of the first token of every window visited so
    # far; together with spaces_no it is compared against the offset `answer`
    token_number = 0
    match = False
    no_match = -1

    # True when exactly one of the symbols follows the answer text somewhere in
    # the context (see is_one_text_in_context)
    one_apperance = is_one_text_in_context(text, context, symbols)

    # iterate over positions in the pos context
    for start_position in range(len(pos_context) - len(tokens) + 1):
        # window tokens with backslash escape sequences decoded
        k_words_context_ = [
            bytes(list(line.values())[0][0], "utf-8").decode("unicode_escape")
            for line in pos_context[start_position : start_position + len(tokens)]
        ]

        # window tokens with backslashes simply removed
        k_words_context = [
            list(line.values())[0][0].replace("\\", "")
            for line in pos_context[start_position : start_position + len(tokens)]
        ]

        # the token right after the window, as a list of zero or one element
        next_word = [
            list(line.values())[0][0]
            for line in pos_context[
                start_position + len(tokens) : start_position + len(tokens) + 1
            ]
        ]


        # join all words in a string
        k_words_context_string = " ".join(k_words_context)
        k_words_context_string_ = " ".join(k_words_context_)

        # exact match: either variant of the window equals the answer tokens
        exact_match = k_words_context == tokens or k_words_context_ == tokens


        # if the answer text occurs exactly once in the context and the window
        # matches it token for token, stop
        if context.count(text) == 1 and exact_match:
            match = True
            break

        # if exactly one symbol follows the answer in the context and the window
        # matches (exactly, or loosely via is_subustring_in_string), stop
        if one_apperance and (
            exact_match
            or is_subustring_in_string(k_words_context_string, text, symbols)
            or is_subustring_in_string(k_words_context_string_, text, symbols)
        ):
            match = True
            break

        # if the occurrence is not the first one, accept the window only when
        # the estimated offset is within one of the expected start offset
        if (
            text in k_words_context_string
            or text in k_words_context_string_
            or exact_match
            or is_subustring_in_string(k_words_context_string, text, symbols)
            or is_subustring_in_string(k_words_context_string_, text, symbols)
        ) and (
            token_number + spaces_no == answer or token_number + spaces_no + 1 == answer or token_number + spaces_no - 1 == answer
        ):
            match = True
            break

        # NOTE(review): next_word is a list while symbols holds strings, so
        # `next_word in symbols` is always False and this branch never runs;
        # no_match therefore stays -1. Possibly `next_word[0] in symbols` was
        # intended -- confirm before changing, since it would alter results.
        elif next_word in symbols and (token_number + spaces_no == answer or token_number + spaces_no + 1 == answer or token_number + spaces_no - 1 == answer):
            no_match = start_position
            break

        # advance the offset estimate by the length of the window's first token.
        # NOTE(review): the test below is list membership, i.e. it is True only
        # when one window token is exactly "'" -- confirm that is the intent.
        if "'" in k_words_context_:
            token_number += len(k_words_context_[0])
        else:
            token_number += len(k_words_context[0])

    if match:
        return start_position
    
    return no_match


In [9]:
def extract_answer(line, answer_index):
    """Tokenise one answer and fetch the context tokens found at its start.

    Args:
        line: Record with "answers" (holding parallel "text" and
            "answer_start" lists), "POS_context" and "context" entries.
        answer_index: Position of the answer inside the record's answer lists.

    Returns:
        A tuple ``(original_answer_list, found_answer, index)``:
        the answer tokens produced by flair's ``Sentence``, the token texts
        taken from "POS_context" starting at ``index`` (same length as the
        answer, or empty when nothing was found), and the start index
        returned by ``find_start_position``.
    """
    answers = line["answers"]
    original_answer = answers["text"][answer_index]
    original_answer_start = answers["answer_start"][answer_index]
    original_answer_list = [token.text for token in Sentence(original_answer)]

    index = find_start_position(original_answer, line["POS_context"], original_answer_start, line["context"])

    if index < 0:
        # A negative index is the "not found" sentinel. Slicing with it would
        # wrap around to the end of the context, so return no tokens instead.
        found_answer = []
    else:
        window = line["POS_context"][index:index + len(original_answer_list)]
        found_answer = [list(token.values())[0][0] for token in window]

    return original_answer_list, found_answer, index

In [10]:
def check_start_answer_detection(data):
    """Report the answers whose start token could not be located.

    For every answer of every record the tokens found by ``extract_answer``
    are joined with spaces and must start with the space-joined answer tokens.

    Args:
        data: Sequence of records as accepted by ``extract_answer``.

    Returns:
        A list of "<line_index>_<answer_index>" keys, one per answer that
        failed the check. The number of records without any passing answer is
        printed.
    """
    incorrect_lines_indices = []
    lines_without_any_answer = []

    for line_index, line in tqdm.tqdm(enumerate(data)):
        has_answer = False

        for answer_index in range(len(line["answers"]["text"])):
            original_answer_list, found_answer, _ = extract_answer(line, answer_index)

            if " ".join(found_answer).startswith(" ".join(original_answer_list)):
                has_answer = True
            else:
                incorrect_lines_indices.append(str(line_index) + "_" + str(answer_index))

        if not has_answer:
            # Record the index of the record, not the list itself.
            lines_without_any_answer.append(line_index)

    print(f"There are {len(lines_without_any_answer)} line without any answer.")

    return incorrect_lines_indices

### 2. Insert NER in the context

In [11]:
def insert_ne_tags(line, answer_index):
    """Wrap named-entity spans of the context in tags and track the answer start.

    Runs of consecutive token indices listed in ``line["NER_context"]`` form
    one span. Each span is surrounded by its tags: "[LABEL]" tokens before
    the span and the same tags in reverse order after it. The labels are read
    from the span's first token.

    Args:
        line: Record with "POS_context" (mappings whose first value starts
            with the token text) and "NER_context" (mappings keyed by token
            index; the second element of each entry in the first value is
            used as the label), plus the fields ``extract_answer`` needs.
        answer_index: Position of the answer inside the record's answer lists.

    Returns:
        A tuple ``(updated_context, new_index, original_answer_list, tokens,
        index)``: the token list with tags inserted, the position of the
        answer start inside it (0 when the start token is never reached, e.g.
        when ``index`` is -1), the tokenised answer, the untagged tokens, and
        the answer start index in the untagged tokens.
    """
    # Find the start index of the answer in the untagged tokens.
    original_answer_list, _, index = extract_answer(line, answer_index)

    tokens = [list(token.values())[0][0] for token in line["POS_context"]]
    ner_context = line["NER_context"]
    ner_token_indices = [int(list(token.keys())[0]) for token in ner_context]

    # Consecutive token indices share the same (position - index) difference,
    # so grouping on it yields the runs of adjacent entity tokens.
    continuous_spans_indices = [
        list(map(itemgetter(1), group))
        for _, group in groupby(
            enumerate(ner_token_indices),
            lambda i_x: i_x[0] - i_x[1],
        )
    ]

    # token index -> position of its entry in NER_context
    token_index_pos_map = {
        token_idx: position for position, token_idx in enumerate(ner_token_indices)
    }

    # Insert the tags and find the new start index of the answer.
    new_index = 0
    updated_context = []
    span_index = 0
    token_index = 0

    # Every token is appended exactly once: the loop runs until token_index
    # reaches len(tokens), so the final token needs no special handling.
    while token_index < len(tokens):
        span = (
            continuous_spans_indices[span_index]
            if span_index < len(continuous_spans_indices)
            else None
        )

        # Plain token: no span left, or the next span starts further on.
        if span is None or token_index < span[0]:
            updated_context.append(tokens[token_index])
            if token_index == index:
                new_index = len(updated_context) - 1
            token_index += 1
            continue

        # Entity span (one or more tokens): opening tags, tokens, closing tags.
        tag_entries = list(ner_context[token_index_pos_map[token_index]].values())[0]
        ne_tags = [tag[1] for tag in tag_entries]
        wrapped_tags = ["[" + ne_tag + "]" for ne_tag in ne_tags]

        updated_context.extend(wrapped_tags)

        for _ in span:
            updated_context.append(tokens[token_index])

            if token_index == index:
                preceded_by_tag = (
                    len(updated_context) >= 2 and updated_context[-2] in NER_TAGS
                )
                if preceded_by_tag:
                    # Answer starts right after the opening tags: let it
                    # start at the first of them.
                    new_index = len(updated_context) - 1 - len(ne_tags)
                else:
                    new_index = len(updated_context) - 1

            token_index += 1

        updated_context.extend(reversed(wrapped_tags))
        span_index += 1

    return updated_context, new_index, original_answer_list, tokens, index

### 3. Find last index of the answer

In [12]:
def find_last_index(updated_context, new_index, original_answer_list, tokens, index):
    """Locate the exclusive end position of the answer in the tagged context.

    Starting at ``new_index`` the answer tokens are consumed one after the
    other (an answer token counts as found when it is contained in the current
    context token), stepping over NE tag tokens on the way. The scan stops at
    the end of the context or after 101 passes.

    Returns:
        The end position (exclusive) when the tokens between ``new_index`` and
        that position, NE tags removed, equal
        ``tokens[index:index + len(original_answer_list)]``; otherwise -1.
    """
    context_length = len(updated_context)
    answer_length = len(original_answer_list)

    cursor = new_index
    matched = 0
    passes = 0

    while matched < answer_length or updated_context[cursor] in NER_TAGS:
        # consume the next answer token when the current context token holds it
        if matched < answer_length and original_answer_list[matched] in updated_context[cursor]:
            matched += 1
            cursor += 1

        # step over a tag token, unless the end of the context was reached
        if cursor < context_length and updated_context[cursor] in NER_TAGS:
            cursor += 1

        if cursor == context_length or passes == 100:
            break

        passes += 1

    expected_answer = tokens[index:index + answer_length]
    untagged_answer = [
        token for token in updated_context[new_index:cursor] if token not in NER_TAGS
    ]

    if expected_answer == untagged_answer:
        return cursor

    return -1

### 4. Process each line and update the data

In [13]:
# Collect the "<line_index>_<answer_index>" keys of all answers whose start
# token could not be located, and report how many there are.
incorrect_lines_indices = check_start_answer_detection(data)
print(f"Number of answers for which the start index was not detected correctly: {len(incorrect_lines_indices)}")

10570it [00:31, 340.58it/s]

There are 15 line without any answer.
Number of answers for which the start index was not detected correctly: 167





In [14]:
print("Cases in which the start index was not detected correctly:\n")

# Each key has the form "<line_index>_<answer_index>"; show the answer tokens,
# the tokens found in the context and the detected start index.
for incorrect_index in incorrect_lines_indices:
    line_index, answer_index = map(int, incorrect_index.split("_"))
    line = data[line_index]

    original_answer_list, found_answer, index = extract_answer(line, answer_index)
    print(original_answer_list, found_answer, index)

Cases in which the start index was not detected correctly:

['8'] ['18'] 91
['2'] ['12'] 65
['10', '.'] ['2010', '.'] 54
['7'] [] -1
['39'] ['739'] 113
['C.', 'J.', 'Anderson'] [] -1
['C.', 'J.', 'Anderson'] [] -1
['5'] ['1,135'] 67
['C.', 'J.', 'Anderson'] [] -1
['C.', 'J.', 'Anderson'] [] -1
['L'] [] -1
['L', '.'] ['Bowl', 'L.'] 52
['0'] ['30'] 25
['ten'] [] -1
['two'] [] -1
['28'] [] -1
['Ted', 'Ginn', 'Jr', '.'] ['to', 'Ted', 'Ginn', 'Jr.'] 20
['T.', 'J.', 'Ward'] [] -1
['T.', 'J.', 'Ward'] [] -1
['Ted', 'Ginn', 'Jr', '.'] ['to', 'Ted', 'Ginn', 'Jr.'] 20
['Ted', 'Ginn', 'Jr', '.'] ['to', 'Ted', 'Ginn', 'Jr.'] 20
['T.', 'J.', 'Ward', '.'] [] -1
['T.', 'J.', 'Ward'] [] -1
['Ted', 'Ginn', 'Jr', '.'] ['to', 'Ted', 'Ginn', 'Jr.'] 20
['wards'] [] -1
['wards'] [] -1
['e', 'Red', 'Army'] ['the', 'Red', 'Army'] 4
['Epte'] ['Saint-Clair-sur-Epte'] 53
['American', 'humor', '.'] ['American', 'humor.', '"'] 133
['street', 'cars'] ["'s", 'streetcars'] 20
['J.', 'P.', 'Morgan'] [] -1
['J.', 'P.',

In [15]:
# Keys ("<line_index>_<answer_index>") of answers that cannot be located in the
# tagged context; starts with the answers whose start index was not detected.
indicies_without_answer = list(incorrect_lines_indices)
# Set mirror of the list above so membership checks stay O(1).
missing_answer_keys = set(indicies_without_answer)

# Intermediate annotation columns dropped from every record once processed.
columns_to_delete = ["NER_context", "POS_context", "NER_question", "POS_question"]

for line_index, line in tqdm.tqdm(enumerate(data)):
    answers = line["answers"]
    # Stays None for a record without answers, so its context is left as is
    # instead of receiving the previous record's tagged context.
    updated_context = None

    for answer_index in range(len(answers["text"])):
        answer_key = str(line_index) + "_" + str(answer_index)
        updated_context, new_index, original_answer_list, tokens, index = insert_ne_tags(line, answer_index)

        if answer_key not in missing_answer_keys:
            new_last_index = find_last_index(updated_context, new_index, original_answer_list, tokens, index)

            if new_last_index == -1:
                indicies_without_answer.append(answer_key)
                missing_answer_keys.add(answer_key)
            else:
                answers["text"][answer_index] = " ".join(updated_context[new_index:new_last_index])

                # Character offset of the answer in the space-joined context:
                # length of everything before it plus the separating space.
                if new_index == 0:
                    answers["answer_start"][answer_index] = 0
                else:
                    answers["answer_start"][answer_index] = len(" ".join(updated_context[:new_index])) + 1

        if answer_key in missing_answer_keys:
            # Mark answers that could not be located.
            answers["text"][answer_index] = ""
            answers["answer_start"][answer_index] = -1

    if updated_context is not None:
        line["context"] = " ".join(updated_context)

    for column_to_delete in columns_to_delete:
        # Missing columns are simply ignored.
        line.pop(column_to_delete, None)

10570it [00:39, 269.09it/s]


### 5. Store the updated data

In [18]:
# Persist the updated records, one JSON object per line.
write_data("../../data/squad_data_validation_ner_span_single_18.json", data)