In [1]:
import sys
sys.path.append("../src")
from mistral_client import run_mistral
from ner_post_processing import parse_entities_promptner, get_token_labels

import pandas as pd
from tqdm import tqdm
from datasets import load_dataset
from collections import defaultdict

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = load_dataset("DFKI-SLT/cross_ner", "politics")

In [3]:
class_labels = dataset["validation"].features["ner_tags"].feature.names
index2label = {i: label for i, label in enumerate(class_labels)}
label2index = {v: k for k, v in index2label.items()}

label2index

{'O': 0,
 'B-academicjournal': 1,
 'I-academicjournal': 2,
 'B-album': 3,
 'I-album': 4,
 'B-algorithm': 5,
 'I-algorithm': 6,
 'B-astronomicalobject': 7,
 'I-astronomicalobject': 8,
 'B-award': 9,
 'I-award': 10,
 'B-band': 11,
 'I-band': 12,
 'B-book': 13,
 'I-book': 14,
 'B-chemicalcompound': 15,
 'I-chemicalcompound': 16,
 'B-chemicalelement': 17,
 'I-chemicalelement': 18,
 'B-conference': 19,
 'I-conference': 20,
 'B-country': 21,
 'I-country': 22,
 'B-discipline': 23,
 'I-discipline': 24,
 'B-election': 25,
 'I-election': 26,
 'B-enzyme': 27,
 'I-enzyme': 28,
 'B-event': 29,
 'I-event': 30,
 'B-field': 31,
 'I-field': 32,
 'B-literarygenre': 33,
 'I-literarygenre': 34,
 'B-location': 35,
 'I-location': 36,
 'B-magazine': 37,
 'I-magazine': 38,
 'B-metrics': 39,
 'I-metrics': 40,
 'B-misc': 41,
 'I-misc': 42,
 'B-musicalartist': 43,
 'I-musicalartist': 44,
 'B-musicalinstrument': 45,
 'I-musicalinstrument': 46,
 'B-musicgenre': 47,
 'I-musicgenre': 48,
 'B-organisation': 49,
 'I-o

In [4]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 200
    })
    validation: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 541
    })
    test: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 651
    })
})

In [5]:
#  You are an expert linguist. Your task 
 
prompt = lambda text: f"""
Dfn: An entity is a person (person), organisation (organisation), politician (politician), political party (politicalparty), event (event), election (election), 
country (country), location (location) or other political entity (misc). Dates, times, abstract concepts, adjectives, and verbs are not entities.

Example 1: Sitting as a Liberal Party of Canada Member of Parliament (MP) for Niagara Falls, she joined the Canadian Cabinet after the Liberals defeated the 
Progressive Conservative Party of Canada government of John Diefenbaker in the 1963 Canadian federal election.

Answer:
1. Liberal Party of Canada | True | as it is a political party (politicalparty)
2. Parliament | True | as it is an organisation (organisation)
3. Niagara Falls | True | as it is a location (location)
4. Canadian Cabinet | True | as it is a political entity (misc)
5. Liberals | True | as it is a political group by not the party name (misc)
6. Progressive Conservative Party of Canada | True | as it is a political party (politicalparty)
7. government | False | as it is not actually an entity in this sentence
8. John Diefenbaker | True | as it is a politician (politician)
9. 1963 Canadian federal election | True | as it is an election (election)

Example 2: The MRE took part to the consolidation of The Olive Tree as a joint electoral list both for the
2004 European Parliament election and the 2006 Italian general election, along with the Democrats of the Left
and Democracy is Freedom - The Daisy.

Answer:
1. MRE | True | as it is a political party (politicalparty)
2. consolidation | False | as it is an action
3. The Olive Tree | True | as it is a group or organisation (organisation)
4. 2004 European Parliament election | True | as it is an election (election)
5. 2006 Italian general election | True | as it is an election (election)
6. Democrats of the Left | True | as it is a political party (politicalparty)
7. Democracy is Freedom - The Daisy | True | as it is an political party (politicalparty)

Q. Given the paragraph below, identify a list of possible entities and for each entry explain why it either is or is not an entity.

Paragraph: {text}
"""


In [6]:
import evaluate

metric = evaluate.load("seqeval")

def score_ner(prediction_batch, gold_batch):
    labeled_predicions = []
    for prediction in prediction_batch:
        labeled_predicions.append([index2label[i] for i in prediction])
    labeled_gold = []
    for gold in gold_batch:
        labeled_gold.append([index2label[i] for i in gold])
    return metric.compute(
        predictions=labeled_predicions, 
        references=labeled_gold
    )

In [7]:
scored = defaultdict(list)

for idx, example in enumerate(tqdm(dataset["validation"].select(range(5)))):
    if (idx + 1) % 100 == 0:
        df_scored = pd.DataFrame(scored)
        df_scored.to_csv("../data/scored/validation.csv", index=False)
    
    text = " ".join(example["tokens"])
    prompt_input = prompt(text)
    output = run_mistral(prompt_input)
    ner_tags = get_token_labels(text, parse_entities_promptner(output), label2index)

    scored["id"].append(example["id"])
    scored["tokens"].append(example["tokens"])
    scored["prompt"].append(prompt_input)
    scored["output"].append(output)
    scored["ner_tags"].append(ner_tags)

df_scored = pd.DataFrame(scored)
df_scored.to_csv("../data/scored/validation.csv", index=False)

100%|██████████| 5/5 [00:41<00:00,  8.22s/it]


In [19]:
scored = defaultdict(list)

for idx, example in enumerate(tqdm(dataset["test"])):
    if (idx + 1) % 100 == 0:
        df_scored = pd.DataFrame(scored)
        df_scored.to_csv("../data/scored/test.csv", index=False)
    
    try:
        text = " ".join(example["tokens"])
        prompt_input = prompt(text)
        output = run_mistral(prompt_input)
        ner_tags = get_token_labels(text, parse_entities_promptner(output), label2index)

        scored["id"].append(example["id"])
        scored["tokens"].append(example["tokens"])
        scored["prompt"].append(prompt_input)
        scored["output"].append(output)
        scored["ner_tags"].append(ner_tags)
    except Exception as e:
        print(e)
        continue

df_scored = pd.DataFrame(scored)
df_scored.to_csv("../data/scored/test.csv", index=False)

  6%|▌         | 37/651 [03:56<1:05:32,  6.40s/it]


KeyboardInterrupt: 

In [24]:
score_ner(df_scored["ner_tags"].to_list(), dataset["test"]["ner_tags"])

ValueError: Found input variables with inconsistent numbers of samples:
[24, 61, 43, 45, 75, 29, 69, 57, 41, 34, 34, 44, 25, 38, 58, 50, 39, 22, 17, 51, 41, 58, 61, 44, 49, 51, 34, 40, 37, 44, 43, 38, 56, 46, 68, 25, 40, 55, 41, 33, 56, 51, 52, 32, 56, 25, 44, 55, 44, 76, 42, 32, 41, 39, 31, 40, 45, 45, 35, 65, 35, 40, 35, 51, 60, 83, 41, 60, 41, 44, 54, 41, 62, 73, 38, 43, 29, 52, 66, 32, 63, 44, 63, 39, 79, 35, 22, 24, 41, 25, 28, 50, 59, 75, 46, 42, 47, 40, 54, 42, 45, 50, 31, 29, 39, 68, 47, 69, 22, 73, 43, 34, 38, 55, 67, 33, 78, 79, 42, 21, 26, 31, 42, 53, 44, 53, 34, 40, 62, 38, 49, 40, 35, 76, 61, 80, 39, 48, 59, 25, 49, 38, 54, 38, 37, 42, 55, 39, 60, 55, 41, 48, 44, 84, 51, 35, 75, 56, 53, 31, 52, 25, 61, 49, 64, 36, 37, 41, 51, 37, 50, 50, 48, 70, 60, 34, 41, 27, 60, 53, 33, 61, 59, 40, 66, 55, 50, 36, 42, 43, 32, 46, 46, 47, 78, 44, 55, 57, 51, 57, 51, 44, 36, 39, 31, 51, 36, 58, 62, 39, 32, 43, 41, 74, 29, 45, 70, 47, 38, 32, 34, 29, 32, 57, 34, 58, 33, 37, 30, 58, 68, 49, 25, 35, 79, 34, 34, 36, 43, 39, 29, 34, 29, 30, 30, 33, 65, 56, 41, 26, 41, 71, 63, 36, 38, 39, 58, 79, 62, 76, 29, 41, 41, 33, 32, 45, 33, 39, 46, 35, 45, 42, 30, 22, 51, 36, 43, 76, 19, 61, 48, 37, 38, 52, 46, 40, 43, 47, 29, 45, 41, 25, 29, 69, 75, 38, 41, 46, 54, 72, 29, 53, 32, 70, 67, 41, 38, 42, 45, 51, 55, 41, 49, 47, 50, 42, 34, 49, 79, 51, 50, 46, 54, 69, 38, 40, 47, 76, 39, 61, 51, 41, 42, 37, 38, 30, 41, 26, 46, 68, 56, 22, 49, 39, 52, 34, 26, 19, 30, 41, 30, 52, 67, 55, 80, 50, 52, 40, 34, 38, 25, 46, 36, 68, 50, 54, 57, 75, 56, 34, 40, 68, 35, 84, 37, 27, 43, 37, 34, 30, 45, 50, 27, 28, 63, 44, 28, 46, 33, 21, 51, 38, 26, 37, 27, 44, 37, 40, 57, 42, 26, 29, 48, 36, 30, 47, 47, 43, 44, 44, 29, 24, 21, 76, 44, 41, 58, 41, 23, 23, 44, 55, 62, 75, 35, 27, 51, 72, 60, 45, 32, 48, 20, 35, 37, 41, 30, 38, 20, 34, 32, 48, 29, 37, 38, 34, 44, 28, 19, 67, 29, 33, 57, 46, 29, 33, 38, 43, 71, 24, 37, 63, 26, 51, 55, 47, 31, 53, 42, 36, 46, 44, 36, 34, 24, 80, 42, 38, 65, 27, 68, 73, 67, 56, 60, 42, 26, 33, 45, 43, 44, 62, 43, 26, 62, 29, 70, 54, 47, 62, 75, 71, 48, 66, 54, 35, 61, 22, 70, 62, 47, 58, 65, 53, 31, 53, 39, 27, 37, 30, 38, 52, 47, 56, 52, 51, 41, 64, 62, 53, 41, 41, 80, 47, 47, 57, 40, 37, 37, 39, 61]
[22, 48, 40, 40, 38, 29, 67, 52, 18, 34, 17, 42, 25, 32, 43, 48, 39, 20, 17, 45, 35, 58, 54, 42, 41, 47, 25, 40, 33, 44, 43, 34, 51, 42, 56, 25, 38, 51, 41, 33, 52, 48, 49, 30, 52, 22, 42, 55, 44, 72, 42, 32, 41, 39, 31, 40, 45, 45, 35, 63, 35, 40, 32, 51, 60, 79, 39, 34, 36, 40, 48, 37, 56, 67, 34, 41, 29, 52, 30, 32, 58, 44, 63, 36, 72, 33, 22, 24, 41, 25, 28, 48, 37, 73, 44, 42, 42, 40, 54, 42, 43, 44, 29, 26, 39, 64, 47, 67, 20, 71, 35, 33, 33, 54, 64, 31, 76, 72, 42, 21, 26, 17, 40, 53, 36, 51, 31, 40, 62, 36, 49, 40, 35, 74, 55, 66, 39, 46, 59, 25, 49, 14, 52, 38, 37, 17, 51, 35, 60, 26, 39, 46, 40, 80, 31, 32, 73, 56, 53, 29, 49, 22, 56, 46, 54, 36, 37, 41, 51, 37, 50, 46, 41, 59, 50, 34, 41, 25, 58, 49, 16, 57, 59, 39, 59, 55, 44, 31, 41, 18, 30, 46, 46, 36, 64, 38, 55, 57, 51, 57, 51, 44, 36, 39, 30, 51, 34, 19, 56, 35, 32, 38, 41, 72, 11, 41, 66, 47, 24, 30, 34, 25, 32, 57, 34, 55, 33, 37, 21, 33, 51, 47, 25, 35, 75, 30, 34, 19, 43, 35, 29, 19, 29, 30, 20, 18, 46, 36, 39, 26, 41, 70, 23, 36, 38, 36, 55, 75, 60, 76, 27, 41, 39, 30, 32, 45, 33, 39, 46, 15, 41, 42, 30, 22, 51, 36, 41, 32, 19, 59, 46, 29, 34, 52, 42, 40, 43, 47, 29, 43, 41, 25, 29, 56, 73, 38, 41, 46, 54, 69, 29, 53, 32, 62, 60, 36, 38, 40, 41, 49, 55, 41, 20, 45, 30, 42, 34, 49, 67, 43, 47, 46, 54, 69, 36, 37, 45, 76, 39, 59, 43, 41, 40, 37, 35, 30, 41, 26, 46, 43, 39, 22, 47, 39, 52, 33, 26, 18, 30, 41, 30, 52, 67, 55, 78, 47, 52, 38, 34, 38, 25, 44, 31, 68, 50, 42, 47, 75, 54, 34, 40, 66, 35, 81, 37, 27, 41, 34, 34, 27, 41, 47, 13, 26, 63, 40, 28, 45, 33, 18, 51, 38, 26, 37, 27, 43, 37, 40, 52, 42, 26, 29, 45, 36, 30, 47, 47, 43, 40, 42, 29, 13, 21, 76, 44, 37, 35, 37, 23, 23, 44, 55, 62, 61, 35, 14, 49, 72, 48, 45, 17, 28, 20, 32, 37, 41, 30, 34, 20, 29, 32, 33, 29, 37, 36, 18, 41, 24, 19, 67, 28, 33, 28, 46, 29, 32, 36, 43, 56, 16, 19, 63, 26, 45, 55, 47, 26, 51, 42, 36, 46, 44, 36, 34, 24, 69, 42, 38, 52, 27, 52, 62, 65, 56, 58, 42, 19, 33, 45, 39, 44, 62, 41, 26, 59, 24, 68, 54, 47, 62, 71, 65, 47, 66, 50, 33, 58, 22, 70, 62, 41, 52, 61, 53, 31, 50, 37, 27, 37, 25, 38, 52, 46, 51, 52, 50, 35, 46, 60, 53, 41, 41, 77, 41, 42, 57, 38, 37, 37, 39, 57]