In [1]:
import json
import time
from pathlib import Path

import openai
from tqdm.auto import tqdm

from utils import get_tagged_context, get_neighboring_sentences

In [2]:
# Sort the listing: Path.iterdir() yields entries in arbitrary, OS-dependent order,
# and later cells select a dataset by position (e.g. test_files[0]).
test_files = sorted(Path("test-data").iterdir())
test_files

[WindowsPath('test-data/test_aida-b.jsonl'),
 WindowsPath('test-data/test_cweb.jsonl'),
 WindowsPath('test-data/test_reddit-comments.jsonl'),
 WindowsPath('test-data/test_reddit-posts.jsonl'),
 WindowsPath('test-data/test_shadowlinks-shadow.jsonl'),
 WindowsPath('test-data/test_shadowlinks-tail.jsonl'),
 WindowsPath('test-data/test_shadowlinks-top.jsonl'),
 WindowsPath('test-data/test_tweeki.jsonl'),
 WindowsPath('test-data/test_wned-wiki.jsonl')]

In [5]:
# Build one (context, wikipedia_id, wikipedia_title) tuple per annotated mention.
# NOTE(review): test_files[0] depends on the directory listing order; select the
# dataset by file name if that order is not guaranteed.
lines = []
file = test_files[0]
print(file)

# Explicit UTF-8: on Windows the default encoding is the locale codepage, which
# would mis-decode non-ASCII characters in the JSONL data.
with open(file, "r", encoding="utf-8") as infile:
    for raw_line in tqdm(infile):
        if not raw_line.strip():
            continue  # tolerate blank lines instead of failing in json.loads
        entry = json.loads(raw_line)
        # Spans, ids and titles are parallel lists, paired positionally.
        for span, wiki_id, wiki_title in zip(
            entry["index"], entry["wikipedia_ids"], entry["wikipedia_titles"]
        ):
            # Wrap the mention in <a>...</a> and keep only nearby sentences
            # (see the printed samples below for the resulting format).
            context = get_tagged_context(entry["text"], span)
            context = get_neighboring_sentences(context)
            lines.append((context, wiki_id, wiki_title))

print(len(lines))
print(*lines[:5], sep="\n")

test-data\test_aida-b.jsonl


0it [00:00, ?it/s]

4485
('SOCCER - <a>JAPAN</a> GET LUCKY WIN , CHINA IN SURPRISE DEFEAT . Nadim Ladki AL-AIN , United Arab Emirates 1996-12-06 Japan began the defence of their Asian Cup title with a lucky 2-1 win against Syria in a Group C championship match on Friday .', 993546, 'Japan_national_football_team')
('SOCCER - JAPAN GET LUCKY WIN , <a>CHINA</a> IN SURPRISE DEFEAT . Nadim Ladki AL-AIN , United Arab Emirates 1996-12-06 Japan began the defence of their Asian Cup title with a lucky 2-1 win against Syria in a Group C championship match on Friday .', 887850, 'China_national_football_team')
('SOCCER - JAPAN GET LUCKY WIN , CHINA IN SURPRISE DEFEAT . Nadim Ladki <a>AL-AIN</a> , United Arab Emirates 1996-12-06 Japan began the defence of their Asian Cup title with a lucky 2-1 win against Syria in a Group C championship match on Friday . But China saw their luck desert them in the second match of the group , crashing to a surprise 2-0 defeat to newcomers Uzbekistan .', 212131, 'Al_Ain')
('SOCCER - JAPA

In [6]:
# Zero-shot prompt template; the single `%s` slot receives the <a>-tagged context.
# The wording (including the trailing space after "confidence.") is kept verbatim
# so new responses remain comparable with those already appended to disk.
prompt = (
    "Context: %s\n"
    'Instruction: Generate a JSON with schema`{"candidates": list[str]}` listing at most 5 '
    "Wikipedia articles the text inside the anchor tag could be referring to, "
    "ordered by confidence. \n"
)

In [7]:
# Preview the fully rendered prompt for one example before spending API calls.
# Index the tuple instead of unpacking into `_`: assigning `_` at top level
# shadows IPython's last-output variable for the rest of the session.
idx = 0
context, target = lines[idx][0], lines[idx][2]  # target: annotated Wikipedia title
print(prompt % context)

Context: SOCCER - <a>JAPAN</a> GET LUCKY WIN , CHINA IN SURPRISE DEFEAT . Nadim Ladki AL-AIN , United Arab Emirates 1996-12-06 Japan began the defence of their Asian Cup title with a lucky 2-1 win against Syria in a Group C championship match on Friday .
Instruction: Generate a JSON with schema`{"candidates": list[str]}` listing at most 5 Wikipedia articles the text inside the anchor tag could be referring to, ordered by confidence. 



In [None]:
# Resume point: set to the first index not yet written if a previous run stopped early.
start_idx = 0
# Indices whose API call failed; collected so they can be retried later.
missing = []
# e.g. test-data/test_aida-b.jsonl -> ../evaluation/parametric-gpt-3/responses-aida-b.jsonl
# split("_", 1) keeps any further underscores in the dataset name.
outfile = Path("../evaluation/parametric-gpt-3") / ("responses-" + file.name.split("_", 1)[-1])
# Create the output directory so the append-mode open() below cannot fail on a fresh checkout.
outfile.parent.mkdir(parents=True, exist_ok=True)

In [None]:
# Query the model for every mention and append one JSON record per line to `outfile`.
# Raising `start_idx` resumes a partial run; indices that still fail go to `missing`.
RATE_LIMIT_BACKOFF_SECONDS = (1, 2, 4, 8, 16)


def request_completion(text: str) -> str:
    """Send one prompt to text-davinci-003 and return the generated text."""
    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=text,
        temperature=0,
        max_tokens=1024,
        top_p=1,
    )
    return response.choices[0].text


def request_completion_with_backoff(text: str, delays=RATE_LIMIT_BACKOFF_SECONDS) -> str:
    """Like `request_completion`, but sleeps and retries after each rate-limit error.

    One attempt is made per entry in `delays`, plus a final attempt whose errors
    (including RateLimitError) propagate to the caller.
    """
    for delay in delays:
        try:
            return request_completion(text)
        except openai.error.RateLimitError:
            time.sleep(delay)
    return request_completion(text)


with open(outfile, "a", encoding="utf-8") as f:
    for idx, (context, wiki_id, wiki_title) in tqdm(
        enumerate(lines[start_idx:], start=start_idx),
        total=len(lines),
        initial=start_idx,
    ):
        try:
            completion = request_completion_with_backoff(prompt % context)
        except (openai.error.APIError, openai.error.RateLimitError) as e:
            # RateLimitError is not a subclass of APIError, so both are listed.
            # Record and continue so one failure does not abort the whole run.
            missing.append(idx)
            print(f"OpenAI API returned {type(e).__name__} at index {idx}: {e}")
            continue

        entry = {
            "context": context,
            "id": wiki_id,
            "title": wiki_title,
            "response": completion,
        }
        json.dump(entry, f)
        f.write("\n")
        f.flush()  # keep the file in step with progress if the kernel dies mid-run

        print("Index   :", idx)
        print("Context :", context)
        print("title   :", wiki_title)
        print("response:", completion)
        print("-" * 50)
        