In [None]:
import json
import time
import random
import openai
import random
from pathlib import Path
from tqdm.auto import tqdm
from utils import *

In [None]:
# Sort the directory listing: `Path.iterdir()` yields entries in arbitrary order,
# and a later cell evaluates `test_files[0]`, which must be the same file on every machine.
test_files = sorted(Path("test-data").iterdir())
# Show the part of each file name after its last "_" (the same suffix used for the output file name).
print(*[path.name.split("_")[-1] for path in test_files], sep="\n")

In [None]:
# Flatten the first test file into (context, mention text, title) triples,
# one per annotated span in each JSONL record.
lines = []
file = test_files[0]

# Read explicitly as UTF-8: the default encoding is locale-dependent and can fail
# (or silently mis-decode) non-ASCII text on some platforms.
with open(file, "r", encoding="utf-8") as infile:
    for raw_line in infile:
        # Skip blank lines (e.g. a trailing newline) instead of crashing in json.loads.
        if not raw_line.strip():
            continue
        entry = json.loads(raw_line)
        # NOTE: zip stops at the shorter list, so spans and titles are assumed to be
        # parallel lists of equal length — TODO confirm against the data format.
        for span, title in zip(entry["index"], entry["wikipedia_titles"]):
            start, end = span
            context = get_tagged_context(entry["text"], span)
            context = get_neighboring_sentences(context, 2)
            lines.append((context, entry["text"][start:end], title))

print(len(lines))
print(*lines[:5], sep="\n")

In [None]:
# User-prompt template, filled with `%` formatting as (context, candidate listing).
# The model is asked to wrap its chosen URL in <answer>...</answer>.
prompt = """In the following context

<context>%s</context>

Determine which of the following is the referent Wikipedia article of the text inside the anchor tag:
%s

provide your answer as `<answer>Wikipedia URL</answer>`."""


def disambiguate(context, mention, model="gpt-3.5-turbo", temperature=0.2, max_tokens=200, top_p=0.3, debug=False, num_candidates=10):
    """Ask a chat model which Wikipedia article an anchored mention refers to.

    Candidate articles are retrieved with ``search_ddg(mention, num_candidates)``,
    inserted into ``prompt`` together with ``context``, and the model's reply is
    parsed with ``get_wikipedia_link``.

    Args:
        context: Text around the mention; per the prompt, the mention is expected
            to sit inside an anchor tag.
        mention: Surface form of the mention, used as the candidate search query.
        model, temperature, max_tokens, top_p: Passed through unchanged to
            ``openai.ChatCompletion.create``.
        debug: If True, print the filled prompt and the raw model reply.
        num_candidates: Number of search results requested from ``search_ddg``
            (default 10, the previously hard-coded value).

    Returns:
        ``(link, candidates)``: the value extracted by ``get_wikipedia_link`` and
        the second value returned by ``search_ddg``.

    Raises:
        ValueError: If ``get_wikipedia_link`` finds no answer in the reply.
    """
    # Retrieve candidate articles for the mention. (An earlier variant tried
    # `get_candidates(mention)` first; that path is disabled.)
    results, candidates = search_ddg(mention, num_candidates)

    # Build the user prompt once, for both the debug print and the request.
    user_prompt = prompt % (context, results)
    if debug:
        print(user_prompt)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_prompt},
    ]

    response = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
    )
    content = response["choices"][0]["message"]["content"]

    if debug:
        print(content)

    # Extract the answer link; surface the raw reply when parsing fails.
    link = get_wikipedia_link(content)
    if link is None:
        print("Content:", content)
        raise ValueError("Model did not respond with an answer")
    return link, candidates

In [None]:
# Smoke-test the pipeline on one mention (with debug output) and display
# the annotated title next to the model's answer and the candidates.
idx = 0
sample_context, sample_mention, expected_title = lines[idx]
print("Context:", sample_context)
print("Mention:", sample_mention)
result, candidates = disambiguate(sample_context, sample_mention, debug=True)
expected_title, result, candidates

In [None]:
# Resume point for the evaluation loop: indices below this are not processed.
# NOTE(review): presumably left over from an interrupted run — set to 0 for a fresh run.
start_idx = 2177
# Indices whose request raised are collected here by the evaluation loop.
missing = []
# Output JSONL named after the input file's suffix (text after its last "_").
# NOTE(review): the directory says "llama-2" but `disambiguate` defaults to
# gpt-3.5-turbo — confirm this is the intended destination.
outfile = f"../evaluation/el-with-ek-llama-2-ddg/responses-{file.name.split('_')[-1]}"

In [None]:
# Keep only the first half of the mentions for this run.
# WARNING: not idempotent — every re-run halves `lines` again; re-run the loading cell first.
half = len(lines) // 2
lines = lines[:half]

In [None]:
# Minimum wall-clock seconds between request starts (presumably chosen to stay
# under the API's per-minute rate limit — confirm for the account in use).
MIN_SECONDS_PER_REQUEST = 21

# Append one JSON record per successfully disambiguated mention; failed indices go to `missing`.
with open(outfile, "a") as f:
    # Start directly at `start_idx` rather than iterating over and skipping the
    # processed prefix; `initial` keeps the progress bar aligned with `lines`.
    for idx in tqdm(range(start_idx, len(lines)), initial=start_idx, total=len(lines)):
        context, mention, title = lines[idx]
        title = "https://en.wikipedia.org/wiki/" + title.replace(" ", "_")

        start_time = time.time()
        try:
            result, candidates = disambiguate(context, mention)
        except Exception as e:
            # Best effort: remember the failure and move on to the next mention.
            missing.append(idx)
            print("Error getting", idx, e)
        else:
            entry = {
                "context": context,
                "title": title,
                "result": result,
                "candidates": candidates,
            }

            json.dump(entry, f)
            f.write("\n")
            # Flush so the file on disk keeps up with `start_idx` if the kernel dies.
            f.flush()
            print("Index  :", idx)
            print("title  :", title)
            print("result :", result)
            print("-" * 50)

        # Advance the resume pointer BEFORE sleeping, so an interrupt during the
        # sleep does not make a resumed run re-query and duplicate this index.
        start_idx = idx + 1

        # Throttle after failures too, so an error (e.g. a rate-limit error) is
        # not immediately followed by another request.
        time.sleep(max(MIN_SECONDS_PER_REQUEST - (time.time() - start_time), 0))