In [None]:
import time
import json
import pickle
import re
from pathlib import Path

import openai
from tqdm.auto import tqdm

from utils import get_tagged_context, get_neighboring_sentences, get_wikipedia_link 

In [None]:
# Lookup tables loaded from local pickles. pickle.load can execute arbitrary code,
# so only load files you produced or otherwise trust.
# mention_entities_counter: used by get_candidates as mention string -> mapping of
# candidate Wikipedia title -> count (higher count is ranked first).
with open('../data/zelda_mention_entities_counter.pickle', 'rb') as handle:
    mention_entities_counter = pickle.load(handle)

# entity_description: used by format_candidates, keyed by Wikipedia title with spaces
# replaced by underscores; values are sliced to 150 chars (presumably description text).
with open('../data/extracts.pkl', 'rb') as handle:
    entity_description = pickle.load(handle)

In [None]:
# Strips every run of non-word characters; used for the loosest lookup key.
punctuation_remover = re.compile(r"[\W]+")

def get_candidates(mention, limit=10, counter=None):
    """Look up candidate Wikipedia titles for a surface-form mention.

    The mention is tried against the lookup table in three progressively looser
    forms, stopping at the first one with a non-empty candidate mapping:
    1. the mention as-is;
    2. lowercased with spaces removed;
    3. form 2 with all non-word characters removed as well.

    Args:
        mention: surface text of the mention.
        limit: maximum number of candidates to return; ``None`` returns all.
        counter: mapping of mention -> {title: count}. Defaults to the
            notebook-level ``mention_entities_counter``.

    Returns:
        List of up to ``limit`` titles ordered by descending count (ties keep the
        mapping's order), or ``None`` when no form of the mention has candidates.
    """
    if counter is None:
        counter = mention_entities_counter

    compact = mention.replace(' ', '').lower()
    lookup_keys = (mention, compact, punctuation_remover.sub("", compact))

    for key in lookup_keys:
        candidates = counter.get(key)
        # An empty mapping is treated like a miss: it has nothing to offer the
        # model and would otherwise yield a prompt with no candidates.
        if candidates:
            break
    else:
        return None

    ranked = sorted(candidates.items(), key=lambda item: item[1], reverse=True)
    return [title for title, _count in ranked[:limit]]

In [None]:
def format_candidates(candidates):
    """Render candidate titles as the numbered list shown to the model.

    Each title becomes its Wikipedia URL (spaces -> underscores) followed by the
    first 150 characters of its entry in ``entity_description`` (a single space
    when the title has no entry) and a literal "...".

    Args:
        candidates: iterable of Wikipedia article titles.

    Returns:
        "Candidates:" followed by one "N. <url> : <snippet>..." line per candidate.
    """
    rendered = ["Candidates:"]
    for position, title in enumerate(candidates, start=1):
        slug = title.replace(" ", "_")
        snippet = entity_description.get(slug, " ")[:150]
        rendered.append(f"{position}. https://en.wikipedia.org/wiki/{slug} : {snippet}...")
    return "\n".join(rendered)

In [None]:
# Prompt template filled by disambiguate via `prompt % (context, candidates_text)`:
# the first %s is the context (presumably with the mention wrapped in an anchor tag
# by get_tagged_context), the second %s is the numbered list from format_candidates.
# The reply is expected as <answer>URL</answer>, which get_wikipedia_link parses.
prompt = """In the following context

<context>%s</context>

Determine which of the following is the referent Wikipedia article of the text inside the anchor tag. 

%s

provide your answer as `<answer>Wikipedia URL</answer>`"""


def disambiguate(context, mention, model="gpt-3.5-turbo", temperature=0.2, max_tokens=200, top_p=0.15, debug=False):
    """Ask the chat model which candidate Wikipedia article `mention` refers to.

    Candidates come from `get_candidates(mention)`. With no candidates nothing is
    sent to the model; with exactly one it is returned directly without an API call.

    Args:
        context: text around the mention, substituted verbatim into `prompt`.
        mention: surface form used to look up candidates.
        model, temperature, max_tokens, top_p: forwarded to
            `openai.ChatCompletion.create`.
        debug: if True, print the prompt before sending it.

    Returns:
        (link, candidate_urls) where `candidate_urls` lists the candidates as
        "https://en.wikipedia.org/wiki/<Title_With_Underscores>" URLs and `link` is
        the sole candidate's URL, whatever `get_wikipedia_link` extracted from the
        model's reply, or None (with an empty list) when there are no candidates.

    Raises:
        ValueError: if `get_wikipedia_link` finds no answer in the model's reply.
    """
    candidates = get_candidates(mention)
    # `not` also covers an empty list, which would otherwise produce a prompt
    # with nothing to choose from.
    if not candidates:
        print(f"No candidates for {mention}")
        return None, []

    candidate_urls = ["https://en.wikipedia.org/wiki/" + title.replace(" ", "_") for title in candidates]

    if len(candidates) == 1:
        # Nothing to disambiguate: skip the API call. Return URLs (not bare titles)
        # so single-candidate results have the same format as model answers.
        print("INFO Only candidate.")
        return candidate_urls[0], candidate_urls

    messages = [{"role": "user", "content": prompt % (context, format_candidates(candidates))}]
    if debug:
        print(messages[-1]["content"])

    response = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
    )
    content = response["choices"][0]["message"]["content"]

    link = get_wikipedia_link(content)
    if link is None:
        print("Content:", content)
        raise ValueError("Model did not respond with an answer")
    return link, candidate_urls

In [None]:
# Path.iterdir() yields entries in arbitrary, filesystem-dependent order; sort so
# that `test_files[0]`, used below, is the same file on every machine.
test_files = sorted(Path("test-data").iterdir())
test_files

In [None]:
# Build (tagged context, mention text, gold Wikipedia title) tuples for every
# annotated span in the first test file.
lines = []
file = test_files[0]
print(file)

# Each non-blank line is a JSON record with "text", "index" (one [start, end]
# character span per mention) and "wikipedia_titles" (gold title per span).
# JSON Lines is conventionally UTF-8, so don't rely on the platform default encoding.
with open(file, 'r', encoding='utf-8') as handle:
    for raw_line in handle:
        if not raw_line.strip():
            continue  # tolerate blank lines such as a trailing empty line
        entry = json.loads(raw_line)
        for span, title in zip(entry["index"], entry["wikipedia_titles"]):
            start, end = span
            context = get_tagged_context(entry["text"], span)
            context = get_neighboring_sentences(context, 2)
            lines.append((context, entry["text"][start:end], title))

total = len(lines)
print(total)
print(*lines[:5], sep="\n")

In [None]:
# Smoke test on one example; debug=True prints the full prompt when the model is queried.
idx = 0
print("Context:", lines[idx][0])
print("Mention:", lines[idx][1])
result, candidates = disambiguate(lines[idx][0], lines[idx][1], debug=True)
# Display (prediction, gold title as stored in `lines` -- a bare title, not a URL, candidates).
result, lines[idx][-1], candidates

In [None]:
# Bookkeeping for the evaluation loop: `start_idx` is the next index of `lines` to
# process (lets an interrupted run resume), `missing` collects indices whose request failed.
start_idx = 0
missing = []
# One responses file per test file, suffixed with the last "_"-separated part of its name.
outfile = "../evaluation/el-with-ek-gpt35-zelda/responses-" + file.name.split("_")[-1]
# open(outfile, "a") does not create parent directories; make sure they exist.
Path(outfile).parent.mkdir(parents=True, exist_ok=True)

In [None]:
# Query the model for every mention from `start_idx` onward and append one JSON record
# per answered mention to `outfile`. `start_idx` always holds the next index to process,
# so re-running this cell after an interruption resumes where it stopped. Indices whose
# request raised are collected in `missing`; mentions without candidates are not written.
with open(outfile, "a") as f:
    for idx, (context, mention, title) in tqdm(
        enumerate(lines[start_idx:], start=start_idx),
        total=len(lines),
        initial=start_idx,
    ):
        # Gold answer in the same URL form that disambiguate returns.
        title = "https://en.wikipedia.org/wiki/" + title.replace(" ", "_")

        try:
            result, candidates = disambiguate(context, mention)
        except Exception as e:
            # Deliberately keep going on API/parse errors; the index is kept for a retry pass.
            missing.append(idx)
            tqdm.write(f"Error getting {idx} {e}")
        else:
            if result is not None:
                entry = {
                    "context": context,
                    "title": title,
                    "result": result,
                    "candidates": candidates,
                }
                json.dump(entry, f)
                f.write("\n")
                # Flush per record so written output stays in sync with `start_idx`
                # even if the kernel dies mid-run.
                f.flush()
                # tqdm.write prints without corrupting the progress bar.
                tqdm.write(
                    f"Index  : {idx}\n"
                    f"title  : {title}\n"
                    f"result : {result}\n"
                    + "-" * 50
                )

        start_idx = idx + 1

In [None]:
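## Get missing lines: keep only the `lines` whose context has no record in `outfile` yet and reset `start_idx`, so the loop above retries them (uncomment to use)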
# with open(outfile, "r") as f:
#     l2 = set()
#     for line in tqdm(f):
#         try:
#             line = json.loads(line)
#         except json.JSONDecodeError:
#             pass
#         else:
#             l2.add(line["context"]) 
#     missing = []
#     for idx, i in enumerate(lines):
#         if i[0] not in l2:
#             missing.append(idx)

# lines = [lines[i] for i in missing]
# missing = []
# start_idx = 0
# len(lines)