### 0. Import libraries and read data

In [1]:
import copy
import json
import tqdm
from itertools import groupby
from operator import itemgetter
from flair.data import Sentence


In [2]:
def read_data(input_file_path):
    data = []

    with open(input_file_path) as f:
        for line in f:
            data.append(json.loads(line))

    return data

In [3]:
def write_data(output_file_path, data):
    # Open a new JSON file for writing
    with open(output_file_path, "w") as output_file:
        for data_line in data:
            output_file.write(json.dumps(data_line) + "\n")

In [18]:
data = read_data("../../data/squad_data_train_pos_ner_agg.json")

In [5]:
NER_TAGS = ['[LOC]', '[MISC]', '[ORG]', '[PER]']

### 1. Find the token index where the answer starts

In [6]:
def is_subustring_in_string(string, substring, symbols):
    for symbol in symbols:
        if symbol in string and (substring in string or substring.replace(" ", "") in string.replace(" ", "")):
           return True
        
    return False

In [7]:
def is_one_text_in_context(text, context, symbols):
    count = 0

    for symbol in symbols:
        if context.count(text + symbol):
            count += 1

    return count == 1

In [8]:
def find_start_position(text, pos_context, answer, context):
    tokens = []
    sentence = Sentence(text)
    symbols = [" ", ")", "(", "%", "-", ".", ","]

    # get how many spaces are until the start position
    spaces_no = context[:answer].count(" ")

    # get tokens of the answer
    for word in sentence:
        tokens.append(word.text)

    start_position = 0
    token_number = 0
    match = False
    no_match = -1

    # check if the answer appear only once
    one_apperance = is_one_text_in_context(text, context, symbols)

    # iterate over positions in the pos context
    for start_position in range(len(pos_context) - len(tokens) + 1):
        # store in k_words_context_ and k_words_context the span of tokens from pos_context
        k_words_context_ = [
            bytes(list(line.values())[0][0], "utf-8").decode("unicode_escape")
            for line in pos_context[start_position : start_position + len(tokens)]
        ]

        k_words_context = [
            list(line.values())[0][0].replace("\\", "")
            for line in pos_context[start_position : start_position + len(tokens)]
        ]

        # get the next word after the context
        next_word = [
            list(line.values())[0][0]
            for line in pos_context[
                start_position + len(tokens) : start_position + len(tokens) + 1
            ]
        ]


        # join all words in a string
        k_words_context_string = " ".join(k_words_context)
        k_words_context_string_ = " ".join(k_words_context_)

        # define exact mach condition
        exact_match = k_words_context == tokens or k_words_context_ == tokens


        # if there is only one appearence and the answer was found, stop
        if context.count(text) == 1 and exact_match:
            match = True
            break

        # if there is only one appearence with the answer followed by a token, stop
        if one_apperance and (
            exact_match
            or is_subustring_in_string(k_words_context_string, text, symbols)
            or is_subustring_in_string(k_words_context_string_, text, symbols)
        ):
            match = True
            break

        # if the occurence is not the first one, but the start position is correct, stop
        if (
            text in k_words_context_string
            or text in k_words_context_string_
            or exact_match
            or is_subustring_in_string(k_words_context_string, text, symbols)
            or is_subustring_in_string(k_words_context_string_, text, symbols)
        ) and (
            token_number + spaces_no == answer or token_number + spaces_no + 1 == answer or token_number + spaces_no - 1 == answer
        ):
            match = True
            break

        # else consider the next word
        elif next_word in symbols and (token_number + spaces_no == answer or token_number + spaces_no + 1 == answer or token_number + spaces_no - 1 == answer):
            no_match = start_position
            break

        if "'" in k_words_context_:
            token_number += len(k_words_context_[0])
        else:
            token_number += len(k_words_context[0])

    if match:
        return start_position
    
    return no_match


In [10]:
def extract_answer(line):
    original_answer = line["answers"]["text"][0]
    original_answer_start = line["answers"]["answer_start"][0]
    original_answer_list = [token.text for token in list(Sentence(original_answer))]

    index = find_start_position(original_answer, line["POS_context"], original_answer_start, line["context"])
    found_answer = [list(token.values())[0][0] for token in line['POS_context'][index:index+len(original_answer_list)]]

    return original_answer_list, found_answer, index

In [9]:
def check_start_answer_detection(data):
    incorrect_lines_indices = []
    
    for i, line in tqdm.tqdm(enumerate(data)):
        original_answer_list, found_answer, index = extract_answer(line)

        if not " ".join(found_answer).startswith(" ".join(original_answer_list)):
            incorrect_lines_indices.append(i)

    return incorrect_lines_indices

### 2. Insert NER in the context

In [11]:
def insert_ne_tags(line):
    """
    Find start index for answer
    """
    original_answer_list, found_answer, index = extract_answer(line)

    """
    Get spans of continuous indices
    """
    # get tokens
    tokens = [list(token.values())[0][0] for token in line["POS_context"]]

    continuous_spans_indices = [
        list(map(itemgetter(1), g))
        for _, g in groupby(
            enumerate([int(list(token.keys())[0]) for token in line["NER_context"]]),
            lambda i_x: i_x[0] - i_x[1],
        )
    ]

    token_index_pos_map = {int(list(token.keys())[0]):index for index, token in enumerate(line["NER_context"])}


    """
    Insert NE tags and find the new start index for answer
    """
    new_index = 0
    updated_context = []
    span_index = 0
    token_index = 0
    initial_token_index = 0

    while token_index < len(tokens):
        try:
            continuous_spans_indices[span_index][0]
        except:
            updated_context.append(tokens[token_index])
            if initial_token_index == index:
                new_index = len(updated_context) - 1

            token_index += 1
            initial_token_index += 1
            continue

        if token_index + 1 > continuous_spans_indices[span_index][0] and len(continuous_spans_indices[span_index]) == 1:
            ne_tags = [tag[1] for tag in list(line["NER_context"][token_index_pos_map[token_index]].values())[0]]

            for ne_tag in ne_tags:
                updated_context.append("[" + ne_tag + "]")

            updated_context.append(tokens[token_index])

            if initial_token_index == index:
                
                if updated_context[-2] in NER_TAGS:
                    new_index = len(updated_context) - 1 - len(ne_tags)
                else:
                    new_index = len(updated_context) - 1

            for ne_tag in ne_tags[::-1]:
                updated_context.append("[" + ne_tag + "]")

            span_index += 1
            token_index += 1
            initial_token_index += 1


        elif token_index + 1 > continuous_spans_indices[span_index][0] and len(continuous_spans_indices[span_index]) > 1:
            ne_tags = [tag[1] for tag in list(line["NER_context"][token_index_pos_map[token_index]].values())[0]]

            for ne_tag in ne_tags:
                updated_context.append("[" + ne_tag + "]")

            for i in range(len(continuous_spans_indices[span_index])):
                updated_context.append(tokens[token_index])

                if initial_token_index == index:
                    if updated_context[-2] in NER_TAGS:
                        new_index = len(updated_context) - 1 - len(ne_tags)
                    else:
                        new_index = len(updated_context) - 1

                token_index += 1
                initial_token_index += 1

            for ne_tag in ne_tags[::-1]:
                updated_context.append("[" + ne_tag + "]")

            span_index += 1

        else:
            updated_context.append(tokens[token_index])
            if initial_token_index == index:
                new_index = len(updated_context) - 1

            token_index += 1
            initial_token_index += 1

        if token_index == len(tokens) - 1:
            updated_context.append(tokens[token_index])
            if initial_token_index == index:
                new_index = len(updated_context) - 1

    return updated_context, new_index, original_answer_list, tokens, index

### 3. Find last index of the answer

In [12]:
def find_last_index(updated_context, new_index, original_answer_list, tokens, index):
    """
    Find last index of answer in the updated context
    """
    new_last_index = new_index
    searched_answer = copy.deepcopy(original_answer_list)

    searching_answer_iter = -1

    while len(searched_answer) != 0 or updated_context[new_last_index] in NER_TAGS:
        searching_answer_iter += 1

        if len(searched_answer) != 0:
            if searched_answer[0] in updated_context[new_last_index]:
                searched_answer.pop(0)
                new_last_index += 1

        try:
            if updated_context[new_last_index] in NER_TAGS:
                new_last_index += 1
        except:
            pass
            
        if new_last_index == len(updated_context) or searching_answer_iter == 100:
            break

    original_answer = tokens[index:index+len(original_answer_list)]
    updated_answer = updated_context[new_index:new_last_index]

    if original_answer != [token for token in updated_answer if token not in NER_TAGS]:
        return -1
    
    return new_last_index

### 4. Process each line and update the data

In [13]:
incorrect_lines_indices = check_start_answer_detection(data)
print(f"Number of lines for which the start index was not detected correctly: {len(incorrect_lines_indices)}")

87599it [01:24, 1038.17it/s]

Number of lines for which the start index was not detected correctly: 398





In [14]:
print("Cases in which the start index was not detected correctly:\n")

for incorrect_index in incorrect_lines_indices:
    line = data[incorrect_index]
    
    original_answer_list, found_answer, index = extract_answer(line)

    print(original_answer_list, found_answer, index)

Cases in which the start index was not detected correctly:

['Father', 'Joseph', 'Carrier', ',', 'C.S.C', '.'] ['Father', 'Joseph', 'Carrier', ',', 'C.S.C.', 'was'] 0
['Rev', '.', 'John', 'J.', 'Cavanaugh', ',', 'C.S.C', '.'] ['The', 'Rev', '.', 'John', 'J.', 'Cavanaugh', ',', 'C.S.C.'] 0
['five', '.'] ['top-five', '.'] 149
['Jay', 'Z', '.'] ['married', 'Jay', 'Z.'] 7
["'", '03', 'Bonnie', '&', 'Clyde'] ['"\'', '03', 'Bonnie', '&', 'Clyde'] 16
['B.I.C', '.'] ['"', 'B.I.C.'] 91
['B.I.C', '.'] ['"', 'B.I.C.'] 91
['50'] [] -1
['J.', 'S.', 'Bach', ',', 'Mozart', 'and', 'Schubert'] [] -1
['J.', 'S.', 'Bach', ',', 'Mozart', 'and', 'Schubert'] [] -1
['J.', 'S.', 'Bach', ',', 'Mozart', 'and', 'Schubert'] [] -1
['7'] [] -1
['7'] [] -1
['7'] [] -1
['Rondo', 'Op', '.', '1', '.'] ['his', 'Rondo', 'Op', '.', '1.'] 162
['the', 'Canuts', '.'] [',', 'the', 'Canuts.'] 42
['J.S', '.', 'Bach', "'s", 'The', 'Well-Tempered', 'Clavier'] ['by', 'J.S.', 'Bach', "'s", 'The', 'Well-Tempered', 'Clavier'] 38
['Op

In [15]:
indicies_without_answer = copy.deepcopy(incorrect_lines_indices)

for line_index, line in tqdm.tqdm(enumerate(data)):
    updated_context, new_index, original_answer_list, tokens, index = insert_ne_tags(line)

    if line_index not in incorrect_lines_indices:
        new_last_index = find_last_index(updated_context, new_index, original_answer_list, tokens, index)
        
        if new_last_index == -1:
            indicies_without_answer.append(line_index)

        else:
            data[line_index]["answers"]["text"] = [" ".join(updated_context[new_index:new_last_index])]
            
            if new_index == 0:
                data[line_index]["answers"]["answer_start"] = [0]
            else:    
                data[line_index]["answers"]["answer_start"] = [len(" ".join(updated_context[:new_index])) + 1]

    data[line_index]["context"] = " ".join(updated_context)

    if line_index in indicies_without_answer:
        data[line_index]["answers"]["text"] = [""]
        data[line_index]["answers"]["answer_start"] = [-1]

    columns_to_delete = ["NER_context", "POS_context", "NER_question", "POS_question"]    

    for column_to_delete in columns_to_delete:
        try:
            del data[line_index][column_to_delete]
        except:
            pass

87599it [01:37, 896.92it/s] 


### 5. Store the updated data

In [16]:
write_data("../../data/squad_data_train_ner_span.json", data)