In [1]:
import json
import pickle
import requests
from pathlib import Path
from urllib.parse import urlparse, unquote
from tqdm.auto import tqdm
from IPython.display import display, Markdown

In [2]:
# Load the cached title lookup that get_accuracy(fix=True) consults via
# normalised_title.get(candidate, "").
# NOTE: pickle can execute arbitrary code while loading — only use a cache
# file that you produced yourself.
normalised_titles_path = Path("../data/normalised_titles.pkl")
with normalised_titles_path.open("rb") as cache_file:
    normalised_title = pickle.load(cache_file)

In [3]:
# Per-dataset denominators used by get_accuracy: the count of correct
# predictions is divided by num_lines[dataset].
# NOTE(review): presumably the total number of examples in each dataset, so
# that rows dropped while parsing still count as misses — confirm.
num_lines = {
    'aida-b': 4485,
    'cweb': 11116,
    'reddit-comments': 638,
    'reddit-posts': 704,
    'shadowlinks-shadow': 904,
    'shadowlinks-tail': 901,
    'shadowlinks-top': 904,
    'tweeki': 860,
    'wned-wiki': 6765
}

In [4]:
def get_matching_articles(titles, timeout=30):
    """Map each title to the top hit of a Wikipedia full-text search.

    One request is sent to the English Wikipedia search API per title.

    Args:
        titles: Iterable of title strings to look up.
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        dict mapping a queried title to the title of its first search result.
        Titles whose search returned no results are omitted. Titles whose
        request failed, or whose response lacked the expected
        ``query``/``search`` structure, map to ``"UNKNOWN"``.
    """
    base_url = 'https://en.wikipedia.org/w/api.php'
    base_params = {
        'action': 'query',
        'format': 'json',
        'list': 'search',
        'srprop': 'size',
        'utf8': 1
    }

    matching_articles = {}

    # The context manager closes the session's pooled connections when done.
    with requests.Session() as s:
        for title in tqdm(titles):
            try:
                # Build the parameters per request instead of mutating a shared dict.
                response = s.get(
                    base_url,
                    params={**base_params, 'srsearch': title},
                    timeout=timeout,
                ).json()
            except (requests.RequestException, ValueError) as e:
                # A transport failure or a non-JSON body should not abort the
                # whole batch; record it like any other failed lookup.
                matching_articles[title] = "UNKNOWN"
                print(f"Error retrieving search results for '{title}': {e}")
                continue

            if 'query' in response and 'search' in response['query']:
                search_results = response['query']['search']
                if len(search_results) > 0:
                    best_match = search_results[0]
                    matching_articles[title] = best_match['title']
            else:
                matching_articles[title] = "UNKNOWN"
                print(f"Error retrieving search results for '{title}'")

    return matching_articles

In [5]:
def parse_json(string):
    """Extract and decode the first JSON object embedded in ``string``.

    Text before the first ``{`` and after the end of the decoded object is
    ignored. Newline characters are removed before decoding.

    Args:
        string: Text that may contain a JSON object.

    Returns:
        The decoded object, or ``{}`` when ``string`` contains no ``{`` or
        no ``}``.

    Raises:
        json.JSONDecodeError: If the text starting at the first ``{`` is not
            a valid JSON object.
    """
    try:
        start = string.index("{")
        # Only checks that a closing brace exists; the decoder finds the real end.
        string.index("}")
    except ValueError:
        return {}

    # Slicing at the first "}" would cut nested objects (or a "}" inside a
    # string value) short, so let the decoder find where the object ends.
    cleaned = string[start:].replace("\n", "")
    obj, _ = json.JSONDecoder().raw_decode(cleaned)
    return obj

  
def convert_wikipedia_urls(strings):
    """Turn Wikipedia-style URLs into human-readable page titles.

    For every input the final segment of the URL path is taken, underscores
    become spaces and percent-escapes are decoded.

    Args:
        strings: Iterable of URL (or bare title) strings.

    Returns:
        list of title strings, in the same order as the input.
    """
    def _title_of(url):
        # Surrounding slashes are stripped first so a trailing "/" does not
        # produce an empty last segment.
        last_segment = urlparse(url).path.strip("/").rsplit("/", 1)[-1]
        return unquote(last_segment.replace("_", " "))

    return [_title_of(url) for url in strings]


def get_target_and_preds(file):
    """Read a JSON-lines prediction file into ``[title, candidates]`` pairs.

    Each non-blank line is decoded as a JSON object that provides a
    ``'title'`` (the target) and a ``'response'`` string. The response is
    passed to ``parse_json`` and its ``'candidates'`` list, if any, is
    converted to page titles with ``convert_wikipedia_urls``.

    Args:
        file: Path of the JSON-lines file to read.

    Returns:
        list of ``[title, candidates]`` pairs. Records whose response cannot
        be decoded are reported on stdout and left out of the result.
    """
    lines = []
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(file, "r", encoding="utf-8") as infile:
        for idx, line in enumerate(infile):
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            line = json.loads(line)
            try:
                response = parse_json(line['response'])
            except json.JSONDecodeError as e:
                print(idx, line['response'], e)
                continue

            candidates = response.get('candidates')
            if isinstance(candidates, list):
                # Keep only string entries; anything else cannot be a URL.
                urls = [c for c in candidates if isinstance(c, str)]
                line['candidates'] = convert_wikipedia_urls(urls)
                # line['candidates'] = get_matching_articles(line['candidates'])
            else:
                # Missing key, or a value that is not a list (e.g. null or a string).
                line['candidates'] = []
            lines.append([line['title'], line['candidates']])
    return lines


def get_markdown_table(results):
    """Render per-dataset scores as a markdown table with one row per k.

    Args:
        results: dict mapping a dataset name to a sequence of scores, where
            index ``i`` holds the score for ``k = i + 1``.

    Returns:
        The markdown table as a string: a header row, a separator row and
        five data rows (k = 1..5). A dataset that is missing from
        ``results``, or that has no score for a given k, is shown as ``-``.
    """
    datasets = ["aida-b", "tweeki", "reddit-posts", "reddit-comments", "wned-wiki",
            "shadowlinks-tail", "shadowlinks-shadow", "shadowlinks-top"]

    header = ('| k |' + ' | '.join(datasets) + ' |' + '\n').upper()
    separator = '|-' + '-|-'.join('-' * 3 for _ in range(len(datasets) + 1)) + '-|' + '\n'

    rows = []
    for i in range(5):
        row = f'| {i + 1}'
        for dataset in datasets:
            try:
                row += ' | ' + str(results[dataset][i])
            except (KeyError, IndexError):
                # KeyError: dataset absent; IndexError: fewer than i + 1 scores.
                row += ' | - '
        rows.append(row + ' |' + '\n')

    table = header + separator + ''.join(rows)
    # Shorten the long column names so the table stays narrow.
    return table.replace("SHADOWLINKS", "SLINKS").replace("COMMENTS", "COMM.")

In [6]:
def get_accuracy(lines, dataset, k=5, fix=False):
    """Compute accuracy@k of predicted titles against their targets.

    Titles are compared case-insensitively with underscores treated as
    spaces. A record counts as correct when its target appears among the
    first ``k`` candidates.

    Args:
        lines: Iterable of ``(target, candidates)`` pairs.
        dataset: Key into ``num_lines`` that supplies the denominator.
        k: Number of leading candidates to consider.
        fix: When true, each candidate is first looked up in
            ``normalised_title``; candidates without an entry become ``""``.

    Returns:
        Number of correct records divided by ``num_lines[dataset]``,
        rounded to three decimals.
    """
    def _canonical(title):
        return title.replace("_", " ").lower()

    hits = 0
    for gold, predicted in lines:
        gold = _canonical(gold)
        if fix:
            predicted = [normalised_title.get(p, "") for p in predicted]
        top_k = [_canonical(p) for p in predicted][:k]
        if gold in top_k:
            hits += 1

    return round(hits / num_lines[dataset], 3)

In [7]:
# Score the raw predictions in every file under parametric-gpt-3 at k = 1..5,
# collecting all predicted titles in pred_titles along the way.
files = list(Path("parametric-gpt-3").iterdir())
pred_titles = []
accuracy_by_dataset = {}

for file in files:
    # Dataset key = file name after its first "-" minus the last six
    # characters (presumably a ".jsonl" extension — TODO confirm).
    name = file.name.split("-", 1)[-1][:-6]
    lines = get_target_and_preds(file)
    pred_titles += [title for _, titles in lines for title in titles]
    accuracy_by_dataset[name] = [get_accuracy(lines, name, k) for k in range(1, 6)]

results = get_markdown_table(accuracy_by_dataset)
with open("results.md", "a") as f:
    # Append mode: each run adds another section to results.md.
    f.writelines(("# GPT-3 (Parametric)\n", results, "\n\n"))

display(Markdown(results))

| K |AIDA-B | TWEEKI | REDDIT-POSTS | REDDIT-COMM. | WNED-WIKI | SLINKS-TAIL | SLINKS-SHADOW | SLINKS-TOP |
|-----|-----|-----|-----|-----|-----|-----|-----|-----|
| 1 | 0.584 | 0.463 | 0.598 | 0.408 | 0.507 | 0.38 | 0.264 | 0.469 |
| 2 | 0.635 | 0.513 | 0.656 | 0.459 | 0.559 | 0.408 | 0.345 | 0.565 |
| 3 | 0.651 | 0.537 | 0.673 | 0.472 | 0.578 | 0.421 | 0.375 | 0.6 |
| 4 | 0.661 | 0.551 | 0.678 | 0.48 | 0.593 | 0.425 | 0.392 | 0.616 |
| 5 | 0.676 | 0.555 | 0.682 | 0.483 | 0.606 | 0.436 | 0.407 | 0.619 |


In [8]:
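# One-off rebuild of the title cache, kept commented out because
# get_matching_articles sends one Wikipedia API request per distinct predicted title.
# To regenerate, run the last line first and then the dump. Note that the dump
# writes ./normalised_titles.pkl, whereas the notebook loads ../data/normalised_titles.pkl.
# import pickle
# with open("normalised_titles.pkl", "wb") as f:
#     pickle.dump(normalised_title, f)

# normalised_title = get_matching_articles(list(set(pred_titles)))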

In [9]:
# Re-score the same prediction files, this time mapping every predicted title
# through normalised_title (fix=True) before comparing it with the target.
accuracy_by_dataset = {}

for file in files:
    # Dataset key = file name after its first "-" minus the last six
    # characters (presumably a ".jsonl" extension — TODO confirm).
    name = file.name.split("-", 1)[-1][:-6]
    lines = get_target_and_preds(file)
    # pred_titles is deliberately not extended here: the earlier scoring cell
    # already collected these titles, and adding them again would duplicate
    # every entry each time this cell runs.
    accuracy_by_dataset[name] = [
        get_accuracy(lines, name, k, fix=True) for k in range(1, 6)
    ]

# Keep the score dict and the rendered table under different names.
results = get_markdown_table(accuracy_by_dataset)
with open("results.md", "a") as f:
    # Append mode: each run adds another section to results.md.
    f.write("# GPT-3 (Parametric, Hallucinations Fixed)\n" + results + "\n\n")

display(Markdown(results))

| K |AIDA-B | TWEEKI | REDDIT-POSTS | REDDIT-COMM. | WNED-WIKI | SLINKS-TAIL | SLINKS-SHADOW | SLINKS-TOP |
|-----|-----|-----|-----|-----|-----|-----|-----|-----|
| 1 | 0.71 | 0.698 | 0.697 | 0.517 | 0.64 | 0.681 | 0.386 | 0.647 |
| 2 | 0.76 | 0.756 | 0.744 | 0.58 | 0.693 | 0.728 | 0.487 | 0.744 |
| 3 | 0.776 | 0.785 | 0.76 | 0.603 | 0.715 | 0.759 | 0.529 | 0.782 |
| 4 | 0.786 | 0.801 | 0.773 | 0.619 | 0.733 | 0.768 | 0.559 | 0.796 |
| 5 | 0.802 | 0.805 | 0.778 | 0.63 | 0.748 | 0.786 | 0.574 | 0.801 |
