In [60]:
import sys
sys.path.append("../src")
from ner_post_processing import parse_entities_promptner, get_token_labels
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset
from collections import defaultdict

In [62]:
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

dataset = load_dataset("DFKI-SLT/cross_ner", "politics")

Found cached dataset cross_ner (/Users/vkmr/.cache/huggingface/datasets/DFKI-SLT___cross_ner/politics/1.1.0/e1d1a6ac35c3ee9d62d89789aad42c65e8266eb7d75bcba812d59e45639c005e)


  0%|          | 0/3 [00:00<?, ?it/s]

In [63]:
class_labels = dataset["validation"].features["ner_tags"].feature.names
index2label = {i: label for i, label in enumerate(class_labels)}
label2index = {v: k for k, v in index2label.items()}

In [64]:
import evaluate

metric = evaluate.load("seqeval")

class_labels = dataset["validation"].features["ner_tags"].feature.names
index2label = {i: label for i, label in enumerate(class_labels)}
label2index = {v: k for k, v in index2label.items()}

def score_ner(prediction_batch,gold_batch):
    labelled_predicions = []
    for prediction in prediction_batch:
        labelled_predicions.append([index2label[i] for i in prediction])
    labelled_gold = []
    for gold in gold_batch:
        labelled_gold.append([index2label[i] for i in gold])
    return metric.compute(
    predictions=labelled_predicions, 
    references=labelled_gold
    )

In [65]:
df = pd.DataFrame(dataset["test"])
instruction = "An entity is a person (person), organization (organization), politician (politician), political party (politicalparty), event (event), election (election), country (country), location (location), or other political entity (misc). Dates, times, abstract concepts, adjectives, and verbs are not entities.\n\nFor each potential entity in the text, determine if it is an entity and, if so, its type. Provide the reason for your decision. Format your response as a YAML list, with each item containing the following fields:\n\nspan: The text span of the potential entity.\nentity_type: The type of the entity (person, organization, politician, politicalparty, event, election, country, location, misc) or false if not an entity.\nreason: A brief explanation of why the span is or is not an entity.",
df["inference_prompt"] = df.apply(lambda x:f"### INSTRUCTION: {instruction} ### PARAGRAPH: {x['tokens']}  ### TAG_SPANS: ",axis=1)

In [66]:
import requests
import concurrent.futures
import time

def generate_completion(model_id, prompt, retries=3, delay=1):
    url = "https://api.fireworks.ai/inference/v1/completions"
    headers = {
        "Authorization": "Bearer FFJxiShwuQO0MyRm7ynfQnDkWdZYosEBIOVEf2AbIyzyAXre",
        "Content-Type": "application/json"
    }
    data = {
        "model": f"accounts/vaibhavk992-6442ca/models/{model_id}",
        "prompt": prompt,
        "max_tokens": 32768,
    }

    for attempt in range(retries):
        try:
            response = requests.post(url, headers=headers, json=data)
            if response.status_code == 200:
                result = response.json()
                return result
            else:
                raise Exception(f"Request failed with status code: {response.status_code}")
        except Exception as e:
            if attempt < retries - 1:
                print(f"Request failed. Retrying in {delay} second(s)...")
                time.sleep(delay)
            else:
                raise e

def process_row(row):
    prompt = row["inference_prompt"]
    result = generate_completion("995b5cf4a000477f87032f0edb0b22ce", prompt)
    return result

with concurrent.futures.ThreadPoolExecutor(max_workers=100) as executor:
    futures = []
    for _, row in df.iterrows():
        future = executor.submit(process_row, row)
        futures.append(future)

    results = [future.result() for future in concurrent.futures.as_completed(futures)]

df["model_raw_output"] = results
df["filtered_output"] = df["model_raw_output"].apply(lambda x: x["choices"][0]["text"])

In [67]:
import json
def apply_safe_json(x):
    try:
        return json.loads(x)
    except:
        return []
df["json_output"] = df["filtered_output"].apply(apply_safe_json)

In [70]:
scored = defaultdict(list)

def parse_entities_finetune(output):
    entities = []
    for entry in output:
        try:
            entity = entry["span"]
            is_entity = entry["entity_type"] != 'false'
            reasoning = entry["reason"]
            tag = entry["entity_type"]
            entities.append((entity, is_entity, reasoning, tag))
        except:
            continue
    return entities


for idx, example in tqdm(df.iterrows()):
    try:
        text = " ".join(example["tokens"])
        ner_tags = get_token_labels(text, parse_entities_finetune(example["json_output"]), label2index)
        scored["id"].append(example["id"])
        scored["tokens"].append(example["tokens"])
        scored["ner_tags"].append(ner_tags)
    except Exception as e:
        print(e)
        continue

df_scored = pd.DataFrame(scored)
score_ner(df_scored["ner_tags"].to_list(), dataset["test"]["ner_tags"])

651it [00:00, 14588.61it/s]
  _warn_prf(average, modifier, msg_start, len(result))


{'country': {'precision': 0.3157894736842105,
  'recall': 0.014354066985645933,
  'f1': 0.027459954233409606,
  'number': 418},
 'election': {'precision': 0.06666666666666667,
  'recall': 0.002304147465437788,
  'f1': 0.004454342984409799,
  'number': 434},
 'event': {'precision': 0.2,
  'recall': 0.020512820512820513,
  'f1': 0.037209302325581395,
  'number': 195},
 'location': {'precision': 0.75,
  'recall': 0.015025041736227046,
  'f1': 0.02945990180032733,
  'number': 599},
 'misc': {'precision': 0.13333333333333333,
  'recall': 0.007751937984496124,
  'f1': 0.014652014652014652,
  'number': 258},
 'organisation': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 513},
 'person': {'precision': 0.04,
  'recall': 0.002824858757062147,
  'f1': 0.005277044854881266,
  'number': 354},
 'politicalparty': {'precision': 0.1875,
  'recall': 0.0062959076600209865,
  'f1': 0.012182741116751269,
  'number': 953},
 'politician': {'precision': 0.3333333333333333,
  'recall': 0.0061855670103