### 0. Import libraries and read data

In [1]:
import json
from tqdm import tqdm
from itertools import groupby
from operator import itemgetter

In [2]:
def read_data(input_file_path):
    data = []

    with open(input_file_path) as f:
        for line in f:
            data.append(json.loads(line))

    return data

In [3]:
data = read_data("../../data/squad_data_train_pos_ner.json")

### 1. Aggregate NE for spans for tokens

In [4]:
tags_set = set()

for example in data:
    for ners in example["NER_context"]:
        index = list(ners.keys())[0]
        for ner in list(ners.values())[0]:
            tags_set.add(ner[1])

entities = set([tag.split("-")[-1] for tag in list(tags_set)])
print(f"Set of all available tags:\n{tags_set}\nEntities: {entities}")


Set of all available tags:
{'I-LOC', 'S-LOC', 'S-PER', 'I-ORG', 'B-ORG', 'E-ORG', 'B-MISC', 'S-MISC', 'I-MISC', 'E-PER', 'E-MISC', 'S-ORG', 'E-LOC', 'B-LOC', 'B-PER', 'I-PER'}
Entities: {'MISC', 'LOC', 'ORG', 'PER'}


In [5]:
for line in tqdm(data):
    # get indices of the tokens from the example
    token_indices = [int(list(token.keys())[0]) for token in line["NER_context"]]

    # get spans of continuous indices
    continuous_indices = [list(map(itemgetter(1), g)) for _, g in groupby(enumerate(token_indices), lambda i_x: i_x[0] - i_x[1])]

    words_ne_each_span = []

    # iterate over each indices span
    for indices_span in continuous_indices:

        # if the span is longer than 1
        if len(indices_span) > 1:
            
            # get all the possible NE for all words within the span
            words_ne_within_span = []

            for index in indices_span:
                for index_dict in line["NER_context"]:
                    if list(index_dict.keys())[0] == str(index):
                        words_ne_within_span.append(list(index_dict.values())[0])

            words_ne_each_span.append(words_ne_within_span)

    # extract the ne tags for all words within each tag
    ne_spans = []

    for words_ne_span in words_ne_each_span:
        ne_span = []

        for word_ne in words_ne_span:
            ne_span.append([value[1] for value in word_ne])

        ne_spans.append(ne_span)

    # aggregate NE tags at span level
    # in this list we have the aggregated NE tags for the spans longer than 1
    new_tags_spans = []

    for ne_span in ne_spans:
        new_tags_span = []
        
        found_agg_ne = False
        max_occurence = 0
        agg_entity = None

        for entity in entities:
            occurences_no = len([True for token in ne_span if entity in [tags.split("-")[-1] for tags in token]])
            
            if occurences_no > max_occurence:
                max_occurence = occurences_no
                agg_entity = entity

            if len(ne_span) == occurences_no:
                new_tags_span.append(entity)
                found_agg_ne = True

        if found_agg_ne == False:
            new_tags_span.append(agg_entity)

        new_tags_spans.append(new_tags_span)

    token_index_pos_map = {int(list(token.keys())[0]):index for index, token in enumerate(line["NER_context"])}

    longer_spans_index = -1

    for indices_span in continuous_indices:
        if len(indices_span) == 1:
            tags = list(line["NER_context"][token_index_pos_map[indices_span[0]]].values())[0]
            new_tags = [[tag[0], tag[1].split("-")[-1]] for tag in tags[:2]]
            line["NER_context"][token_index_pos_map[indices_span[0]]][str(indices_span[0])] = new_tags

        else:
            longer_spans_index += 1

            for span_index in indices_span:
                token = list(line["NER_context"][token_index_pos_map[span_index]].values())[0][0][0]

                new_tags = []

                # if no token was found for the span, the most probable token is added
                for tag in new_tags_spans[longer_spans_index]:
                    new_tags.append([token, tag.split("-")[-1]])

                line["NER_context"][token_index_pos_map[span_index]][str(span_index)] = new_tags

  0%|          | 0/87599 [00:00<?, ?it/s]

100%|██████████| 87599/87599 [00:24<00:00, 3583.08it/s]


In [6]:
def write_data(output_file_path, data):
    # Open a new JSON file for writing
    with open(output_file_path, "w") as output_file:
        for data_line in data:
            output_file.write(json.dumps(data_line) + "\n")

In [7]:
write_data("../../data/squad_data_train_pos_ner_agg.json", data)

### 2. Check aggregated NE for spans of tokens

In [8]:
data = read_data("../../data/squad_data_train_pos_ner.json")

In [9]:
data_agg = read_data("../../data/squad_data_train_pos_ner_agg.json")

In [10]:
more_than_two_ne = 0
empty_ne = 0

for line in data_agg:
    for token in line["NER_context"]:
        ne_no = len(list(token.values())[0])

        if ne_no > 2:
            more_than_two_ne += 1
        if ne_no == 0:
            empty_ne += 1

print(f"Number of tokens that have more than 2 ne: {more_than_two_ne}\nNumber of tokens without ne: {empty_ne}")

Number of tokens that have more than 2 ne: 3010
Number of tokens without ne: 0


- for example with index 50000 there are spans that in the end do not have any NE tag

In [17]:
data[888]["NER_context"]

[{'0': [['Beyoncé', 'S-PER', '0.97757655']]},
 {'19': [['US', 'S-LOC', '0.99831045']]},
 {'34': [['Destiny', 'B-ORG', '0.5214504'],
   ['Destiny', 'B-MISC', '0.24072735']]},
 {'35': [["'s", 'I-MISC', '0.5378261'], ["'s", 'I-ORG', '0.2589195']]},
 {'36': [['Child', 'E-MISC', '0.740782'], ['Child', 'E-ORG', '0.21769544']]},
 {'52': [['Recording', 'B-ORG', '0.7557777']]},
 {'53': [['Industry', 'I-ORG', '0.9865926']]},
 {'54': [['Association', 'I-ORG', '0.92794174']]},
 {'55': [['of', 'I-ORG', '0.99845254']]},
 {'56': [['America', 'E-ORG', '0.9987192']]},
 {'58': [['RIAA', 'S-ORG', '0.9977986']]},
 {'61': [['Beyoncé', 'S-ORG', '0.62609166'],
   ['Beyoncé', 'S-PER', '0.26806518']]},
 {'81': [['Crazy', 'B-MISC', '0.98487943']]},
 {'82': [['in', 'I-MISC', '0.9824212']]},
 {'83': [['Love', 'E-MISC', '0.9944148']]},
 {'87': [['Single', 'B-MISC', '0.44712684']]},
 {'88': [['Ladies', 'E-MISC', '0.98573166']]},
 {'90': [['Put', 'B-MISC', '0.71322274']]},
 {'91': [['a', 'I-MISC', '0.99563205']]},
 

In [18]:
data_agg[888]["NER_context"]

[{'0': [['Beyoncé', 'PER']]},
 {'19': [['US', 'LOC']]},
 {'34': [['Destiny', 'MISC'], ['Destiny', 'ORG']]},
 {'35': [["'s", 'MISC'], ["'s", 'ORG']]},
 {'36': [['Child', 'MISC'], ['Child', 'ORG']]},
 {'52': [['Recording', 'ORG']]},
 {'53': [['Industry', 'ORG']]},
 {'54': [['Association', 'ORG']]},
 {'55': [['of', 'ORG']]},
 {'56': [['America', 'ORG']]},
 {'58': [['RIAA', 'ORG']]},
 {'61': [['Beyoncé', 'ORG'], ['Beyoncé', 'PER']]},
 {'81': [['Crazy', 'MISC']]},
 {'82': [['in', 'MISC']]},
 {'83': [['Love', 'MISC']]},
 {'87': [['Single', 'MISC']]},
 {'88': [['Ladies', 'MISC']]},
 {'90': [['Put', 'MISC']]},
 {'91': [['a', 'MISC']]},
 {'92': [['Ring', 'MISC']]},
 {'93': [['on', 'MISC']]},
 {'94': [['It', 'MISC']]},
 {'95': [[')"', 'MISC']]},
 {'98': [['Halo', 'MISC']]},
 {'103': [['Irreplaceable', 'MISC']]},
 {'119': [['The', 'ORG']]},
 {'120': [['Observer', 'ORG']]},
 {'127': [['Decade', 'MISC']]},
 {'129': [['Billboard', 'MISC'], ['Billboard', 'MISC']]},
 {'148': [['Billboard', 'MISC'], ['

In [13]:
# TODO: check tags format when we have 18 classes