In [2]:
import csv
import os
import re

from flair.data import Sentence
from flair.models import SequenceTagger
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the input sentences: one dict per CSV row, keyed by the header
# (the extraction cell reads the "id" and "sentence" columns).
# newline="" is what the csv module requires for correct handling of quoted
# fields; an explicit encoding avoids depending on the platform default
# (cp1252 on Windows). NOTE(review): assumes the file is UTF-8 — confirm.
with open("../vasari-kg.github.io/data/sentences_en.csv", "r", newline="", encoding="utf-8") as f:
    sentences = list(csv.DictReader(f, delimiter=","))

In [4]:
# Pretrained Flair tagger with the OntoNotes tag set (PERSON, ORG, GPE, DATE, ...).
# The first call downloads the weights; later runs load them from the local cache.
NER_MODEL_NAME = "flair/ner-english-ontonotes-large"
tagger = SequenceTagger.load(NER_MODEL_NAME)

2022-06-15 12:34:10,845 loading file C:\Users\CSA\.flair\models\ner-english-ontonotes-large\2da6c2cdd76e59113033adf670340bfd820f0301ae2e39204d67ba2dc276cc28.ec1bdb304b6c66111532c3b1fc6e522460ae73f1901848a4d0362cdf9760edb1
2022-06-15 12:34:34,794 SequenceTagger predicts: Dictionary with 76 tags: <unk>, O, B-CARDINAL, E-CARDINAL, S-PERSON, S-CARDINAL, S-PRODUCT, B-PRODUCT, I-PRODUCT, E-PRODUCT, B-WORK_OF_ART, I-WORK_OF_ART, E-WORK_OF_ART, B-PERSON, E-PERSON, S-GPE, B-DATE, I-DATE, E-DATE, S-ORDINAL, S-LANGUAGE, I-PERSON, S-EVENT, S-DATE, B-QUANTITY, E-QUANTITY, S-TIME, B-TIME, I-TIME, E-TIME, B-GPE, E-GPE, S-ORG, I-GPE, S-NORP, B-FAC, I-FAC, E-FAC, B-NORP, E-NORP, S-PERCENT, B-ORG, E-ORG, B-LANGUAGE, E-LANGUAGE, I-CARDINAL, I-ORG, S-WORK_OF_ART, I-QUANTITY, B-MONEY


In [5]:
# OntoNotes entity type -> coarse label. Types not listed here map to "MISC".
_ONTONOTES_TO_COARSE = {
    "PERSON": "PER",
    # The OntoNotes tagger emits "ORG" (B-ORG / S-ORG / E-ORG in its tag
    # dictionary); "ORGANIZATION" is kept for backward compatibility.
    "ORG": "ORG",
    "ORGANIZATION": "ORG",
    # Geo-political entities, facilities and other locations are all places.
    "GPE": "LOC",
    "FAC": "LOC",
    "LOC": "LOC",
    "DATE": "DATE",
}


def convert_label(label: str) -> str:
    """Map an OntoNotes NER type to the coarse PER/ORG/LOC/DATE/MISC scheme.

    Args:
        label: entity type predicted by the OntoNotes tagger, e.g. "PERSON",
            "ORG", "GPE".

    Returns:
        "PER", "ORG", "LOC" or "DATE" for the mapped types, "MISC" otherwise.
    """
    return _ONTONOTES_TO_COARSE.get(label, "MISC")

In [6]:
# Numeric / measure entity types that are not kept in the output.
EXCLUDED_TYPES = {"TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"}

# A leading English article ("the" / "a", any case) followed by whitespace.
# Raw string: "\s" in a plain literal is an invalid escape sequence.
LEADING_ARTICLE = re.compile(r"^(?:the|a)\s+(.*)$", re.IGNORECASE)

OUTPUT_PATH = "results3/ontonotes_en/output.csv"
FIELDNAMES = ["id", "start_pos", "end_pos", "surface", "type", "score"]

output = []

for sample in tqdm(sentences):
    sent_idx = sample["id"]
    sentence = Sentence(sample["sentence"])
    tagger.predict(sentence)
    for entity in sentence.get_spans("ner"):
        label = entity.get_label("ner")
        ner_type = label.value
        if ner_type in EXCLUDED_TYPES:
            continue
        start_pos = entity.start_position
        end_pos = entity.end_position
        surface = entity.text
        match = LEADING_ARTICLE.match(surface)
        if match:
            # Drop the article; the end offset is unchanged, so recompute
            # the start offset from it and the shortened surface.
            surface = match.group(1)
            start_pos = end_pos - len(surface)
        output.append({
            "id": sent_idx,
            "start_pos": start_pos,
            "end_pos": end_pos,
            "surface": surface,
            "type": convert_label(ner_type),
            "score": label.score,
        })

# Explicit fieldnames keep the header correct even when no entity was found
# (output[0] would raise IndexError). newline="" is required by the csv module
# to avoid blank rows on Windows, and utf-8 avoids encode errors on non-ASCII
# surfaces under a cp1252 platform default.
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
with open(OUTPUT_PATH, "w", newline="", encoding="utf-8") as out_file:
    writer = csv.DictWriter(out_file, fieldnames=FIELDNAMES)
    writer.writeheader()
    writer.writerows(output)

100%|██████████████████████████████████████████████████████████████████████████████████| 40/40 [00:38<00:00,  1.05it/s]
