### 0. Import libraries and read data

In [1]:
import json
from tqdm import tqdm
from itertools import groupby
from operator import itemgetter

In [2]:
def read_data(input_file_path):
    """Read a JSON Lines file into a list.

    Args:
        input_file_path: path of a text file holding one JSON document per line.

    Returns:
        A list with the decoded value of every non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8 by specification, so do not rely on the platform default encoding
    with open(input_file_path, encoding="utf-8") as f:
        # blank lines (e.g. an extra newline at the end of the file) carry no record
        return [json.loads(line) for line in f if line.strip()]

In [3]:
# Load the input examples (the file is read line by line, one JSON document per line).
# NOTE(review): this reads "..._pos_ner.json", while the result is later saved as
# "..._pos_ner_18_agg.json" and compared against "..._pos_ner_18.json" — confirm that
# this is the intended input file.
data = read_data("../../data/squad_data_validation_pos_ner.json")

### 1. Aggregate NE for spans of tokens

In [4]:
# Collect every distinct NER tag in the dataset. Each NER_context entry is a
# single-key dict whose value is a list of predictions; the tag is the second
# field of each prediction.
tags_set = {
    ner[1]
    for example in data
    for ners in example["NER_context"]
    for ner in list(ners.values())[0]
}

# Keep only the part after the last "-" (e.g. "B-LOC" -> "LOC") to get the entity types
entities = {tag.split("-")[-1] for tag in tags_set}
print(f"Set of all available tags:\n{tags_set}\nEntities: {entities}")


Set of all available tags:
{'I-MISC', 'E-MISC', 'B-LOC', 'I-ORG', 'E-ORG', 'S-ORG', 'E-LOC', 'B-PER', 'S-PER', 'S-MISC', 'E-PER', 'S-LOC', 'B-ORG', 'I-PER', 'I-LOC', 'B-MISC'}
Entities: {'LOC', 'MISC', 'ORG', 'PER'}


In [5]:
# Single-token annotations keep at most this many of their leading predictions
MAX_SINGLE_TOKEN_TAGS = 2


def _entry_predictions(entry):
    """Return the prediction list stored under the first key of a NER_context entry."""
    return list(entry.values())[0]


def _strip_prefix(tag):
    """Keep only the part of a tag after its last "-", e.g. "B-LOC" -> "LOC"."""
    return tag.split("-")[-1]


def _group_consecutive(indices):
    """Split a list of integers into runs of consecutive values, preserving input order."""
    # inside a run of consecutive values, (position - value) stays constant
    return [
        [value for _, value in run]
        for _, run in groupby(enumerate(indices), lambda pos_value: pos_value[0] - pos_value[1])
    ]


def _aggregate_span_entities(span_tags):
    """Choose the entity labels shared by a span of adjacent tokens.

    Args:
        span_tags: one list of tags per token of the span.

    Returns:
        The entities that occur on every token of the span, in alphabetical order.
        If no entity is shared by all tokens, a one-element list holding the entity
        that occurs on the largest number of tokens (ties resolved alphabetically).
    """
    entities_per_token = [{_strip_prefix(tag) for tag in token_tags} for token_tags in span_tags]

    # Candidates come from the span itself and are visited in sorted order, so the
    # result neither depends on set iteration order nor on state from another cell.
    candidates = sorted(set().union(*entities_per_token))

    shared_entities = []
    most_frequent_entity = None
    max_occurrence = 0

    for entity in candidates:
        # an entity counts once per token, however many of the token's tags carry it
        occurrences = sum(entity in token_entities for token_entities in entities_per_token)

        if occurrences > max_occurrence:
            max_occurrence = occurrences
            most_frequent_entity = entity

        if occurrences == len(entities_per_token):
            shared_entities.append(entity)

    return shared_entities if shared_entities else [most_frequent_entity]


# Rewrites line["NER_context"] in place for every example of `data`.
for line in tqdm(data):
    ner_context = line["NER_context"]

    # get indices of the tokens from the example
    token_indices = [int(list(entry.keys())[0]) for entry in ner_context]

    # position of each token index inside NER_context, for direct lookup
    token_index_pos_map = {token_index: pos for pos, token_index in enumerate(token_indices)}

    # iterate over each span of continuous indices
    for indices_span in _group_consecutive(token_indices):
        span_entries = [ner_context[token_index_pos_map[token_index]] for token_index in indices_span]

        if len(indices_span) == 1:
            # isolated token: keep its leading predictions, reduced to [token, entity]
            # NOTE(review): two kept tags of the same entity (e.g. "E-DATE" and "I-DATE")
            # produce the same [token, entity] pair twice — confirm this is acceptable.
            predictions = _entry_predictions(span_entries[0])
            new_tags = [
                [prediction[0], _strip_prefix(prediction[1])]
                for prediction in predictions[:MAX_SINGLE_TOKEN_TAGS]
            ]
            span_entries[0][str(indices_span[0])] = new_tags

        else:
            # span longer than 1: every token of the span receives the same entities
            span_tags = [
                [prediction[1] for prediction in _entry_predictions(entry)]
                for entry in span_entries
            ]
            span_entities = _aggregate_span_entities(span_tags)

            for span_index, entry in zip(indices_span, span_entries):
                # token text is taken from the first prediction of the entry
                token = _entry_predictions(entry)[0][0]

                # if no entity is shared by the whole span, span_entities holds
                # only the most frequent entity of the span
                entry[str(span_index)] = [[token, entity] for entity in span_entities]

100%|██████████| 10570/10570 [00:09<00:00, 1098.81it/s]


In [6]:
def write_data(output_file_path, data):
    """Write `data` to `output_file_path` as JSON Lines.

    Args:
        output_file_path: destination path; an existing file is overwritten.
        data: iterable of JSON-serializable items, each written on its own line.
    """
    with open(output_file_path, "w") as output_file:
        # one serialized item per line, each terminated by a newline
        output_file.writelines(json.dumps(record) + "\n" for record in data)

In [7]:
# Save the examples in `data` (whose NER_context was rewritten in place by the
# aggregation loop) as one JSON document per line.
write_data("../../data/squad_data_validation_pos_ner_18_agg.json", data)

### 2. Check aggregated NE for spans of tokens

In [8]:
# Load the "..._pos_ner_18.json" file for comparison with the aggregated output.
# NOTE: this rebinds `data`, so the list modified by the aggregation loop is no
# longer reachable under this name.
data = read_data("../../data/squad_data_validation_pos_ner_18.json")

In [9]:
# Load the aggregated file (same path that write_data is called with in this notebook).
data_agg = read_data("../../data/squad_data_validation_pos_ner_18_agg.json")

In [10]:
# Count the tokens of the aggregated data whose annotation lists more than two entries
more_than_two_ne = sum(
    len(list(token.values())[0]) > 2
    for line in data_agg
    for token in line["NER_context"]
)

print(f"Number of tokens that have more than 2 ne: {more_than_two_ne}")

Number of tokens that have more than 2 ne: 339


In [18]:
# Display the NER_context of example 888 from `data` for a manual comparison
data[888]["NER_context"]

[{'3': [['syrenka', 'S-PERSON', '0.37249884']]},
 {'6': [['Warsaw', 'S-GPE', '0.9950407']]},
 {'34': [['at', 'B-DATE', '0.11415275']]},
 {'35': [['least', 'I-DATE', '0.54797506']]},
 {'36': [['the', 'I-DATE', '0.5037803']]},
 {'37': [['mid-14th', 'I-DATE', '0.93928146']]},
 {'38': [['century', 'E-DATE', '0.9985323']]},
 {'46': [['Warsaw', 'S-GPE', '0.9953707']]},
 {'50': [['year', 'E-DATE', '0.28886825'],
   ['year', 'I-DATE', '0.19738819'],
   ['year', 'S-DATE', '0.10291117']]},
 {'51': [['1390', 'E-DATE', '0.98943406']]},
 {'61': [['Latin', 'S-NORP', '0.95007807']]},
 {'63': [['Sigilium', 'S-NORP', '0.28365117']]},
 {'64': [['Civitatis', 'E-PERSON', '0.106351845']]},
 {'65': [['Varsoviensis', 'E-ORG', '0.1934269'],
   ['Varsoviensis', 'E-PERSON', '0.10567914']]},
 {'72': [['Warsaw', 'S-GPE', '0.9180408']]},
 {'75': [['City', 'S-GPE', '0.10257403']]},
 {'81': [['1609', 'S-DATE', '0.96907514']]},
 {'107': [['1653', 'S-DATE', '0.9954269']]},
 {'110': [['Zygmunt', 'B-PERSON', '0.9383151'

In [19]:
# Display the NER_context of the same example (888) from the aggregated data
data_agg[888]["NER_context"]

[{'3': [['syrenka', 'PERSON']]},
 {'6': [['Warsaw', 'GPE']]},
 {'34': [['at', 'DATE']]},
 {'35': [['least', 'DATE']]},
 {'36': [['the', 'DATE']]},
 {'37': [['mid-14th', 'DATE']]},
 {'38': [['century', 'DATE']]},
 {'46': [['Warsaw', 'GPE']]},
 {'50': [['year', 'DATE']]},
 {'51': [['1390', 'DATE']]},
 {'61': [['Latin', 'NORP']]},
 {'63': [['Sigilium', 'PERSON']]},
 {'64': [['Civitatis', 'PERSON']]},
 {'65': [['Varsoviensis', 'PERSON']]},
 {'72': [['Warsaw', 'GPE']]},
 {'75': [['City', 'GPE']]},
 {'81': [['1609', 'DATE']]},
 {'107': [['1653', 'DATE']]},
 {'110': [['Zygmunt', 'PERSON']]},
 {'111': [['Laukowski', 'PERSON']]}]

In [13]:
# TODO: check tags format when we have 18 classes