In [1]:
import json
import requests
from pathlib import Path
from tqdm.auto import tqdm
from urllib.parse import unquote
from IPython.display import display, Markdown
from utils import is_same_page, get_unique_lines, get_markdown_table

In [2]:
# Per-dataset example counts, keyed by dataset name. The evaluation cell looks
# these up as the `total` denominator for Accuracy and CL-Recall, so examples
# missing from a results file still count against the score.
num_lines = dict(
    [
        ('aida-b', 4485),
        ('cweb', 11116),
        ('reddit-comments', 638),
        ('reddit-posts', 704),
        ('shadowlinks-shadow', 904),
        ('shadowlinks-tail', 901),
        ('shadowlinks-top', 904),
        ('tweeki', 860),
        ('wned-wiki', 6765),
    ]
)

In [3]:
# Evaluate every run folder: score each `<prefix>-<dataset>.jsonl` prediction
# file, dump the wrong links to `incorrect-links/<folder>.jsonl`, display a
# markdown table per folder and append it to `results.md`.

WIKI_PREFIX = "https://en.wikipedia.org/wiki/"
RUN_FOLDERS = [
    "parametric-llama-2",
    "el-with-ek-gpt35-ddg",
    "el-with-ek-gpt35-zelda",
    "el-with-ek-llama-2-ddg",
]
INCORRECT_DIR = Path("incorrect-links")


def _as_wiki_url(page):
    """Return `page` unchanged if it starts with "https://"; otherwise treat it
    as a bare title and build an English-Wikipedia URL (spaces -> underscores)."""
    if page.startswith("https://"):
        return page
    return WIKI_PREFIX + page.replace(" ", "_")


def _page_title(page):
    """Return the part of `page` after the Wikipedia URL prefix.

    Bare titles (anything not starting with "https://") are returned as-is
    instead of being blindly truncated by the prefix length.
    """
    if page.lower().startswith("https://"):
        return page[len(WIKI_PREFIX):]
    return page


def _ratio(numerator, denominator):
    """Return `numerator / denominator` rounded to 3 places, or 0.0 when the
    denominator is zero (e.g. no multi-candidate examples in a dataset)."""
    if not denominator:
        return 0.0
    return round(numerator / denominator, 3)


def _evaluate_file(file, dataset, outfile):
    """Score one prediction file and log its incorrect links.

    Args:
        file: path of the jsonl file, passed to `get_unique_lines`.
        dataset: key into `num_lines` giving the score denominator.
        outfile: open text handle; one JSON object per wrong link is written.

    Returns:
        dict with "Accuracy", "Accuracy (Disamb.)" and "CL-Recall".

    Raises:
        KeyError: if `dataset` is not present in `num_lines`.
    """
    lines = get_unique_lines(file)
    total = num_lines[dataset]
    correct = 0
    correct_disamb = 0
    target_present = 0
    target_present_multi = 0

    # The session is only handed to `is_same_page`; the context manager
    # guarantees it is closed even if a lookup raises.
    with requests.session() as session:
        for line in lines:
            target = line["title"].lower()
            result = _as_wiki_url(unquote(line["result"].lower()))
            candidates = [
                unquote(_as_wiki_url(candidate)).lower()
                for candidate in line["candidates"]
            ]

            # Only examples whose target is among the candidates are scored
            # (they make up CL-Recall); the rest still count in `total`.
            if target not in candidates:
                continue

            is_disamb = len(candidates) > 1
            target_present += 1
            if is_disamb:
                target_present_multi += 1

            # Exact URL match first; otherwise defer to `is_same_page` with
            # the original-case titles.
            # NOTE(review): presumably `is_same_page` resolves redirects/aliases
            # and accepts bare titles with spaces - confirm in utils.
            if result == target or is_same_page(
                _page_title(line["result"]), _page_title(line["title"]), session
            ):
                correct += 1
                if is_disamb:
                    correct_disamb += 1
            else:
                outfile.write(json.dumps(
                    {
                        "dataset": dataset,
                        "context": line["context"].replace("\n", ' '),
                        "target": target,
                        "result": result,
                        "candidates": candidates,
                    }))
                outfile.write("\n")

    return {
        "Accuracy": _ratio(correct, total),
        "Accuracy (Disamb.)": _ratio(correct_disamb, target_present_multi),
        "CL-Recall": _ratio(target_present, total),
    }


# Opening the log file in "w" mode fails if its directory is missing.
INCORRECT_DIR.mkdir(exist_ok=True)

with open("results.md", "w") as results_file:
    for folder in RUN_FOLDERS:
        files = [f for f in Path(folder).iterdir() if f.name.endswith(".jsonl")]
        results = {}

        with open(INCORRECT_DIR / (folder + ".jsonl"), "w") as outfile:
            for file in tqdm(files, desc=folder):
                # Dataset names may contain hyphens themselves
                # (e.g. "reddit-posts"), so split on the first one only.
                dataset = file.stem.split("-", 1)[-1]
                results[dataset] = _evaluate_file(file, dataset, outfile)

        table = get_markdown_table(results)
        display(Markdown(table))
        results_file.write("# " + folder + "\n" + table + "\n\n")

  0%|          | 0/8 [00:00<?, ?it/s]

|    |AIDA-B | TWEEKI | REDDIT-POSTS | REDDIT-COMM. | WNED-WIKI | SLINKS-TAIL | SLINKS-SHADOW | SLINKS-TOP |
|-----|-----|-----|-----|-----|-----|-----|-----|-----|
| CL-Recall | 0.873 | 0.905 | 0.974 | 0.967 | 0.98 | 0.915 | 0.537 | 0.704 |
| Accuracy | 0.609 | 0.66 | 0.764 | 0.685 | 0.606 | 0.912 | 0.253 | 0.459 |
| Accuracy (Disamb.) | 0.581 | 0.637 | 0.734 | 0.658 | 0.446 | 0.875 | 0.434 | 0.601 |


  0%|          | 0/8 [00:00<?, ?it/s]

|    |AIDA-B | TWEEKI | REDDIT-POSTS | REDDIT-COMM. | WNED-WIKI | SLINKS-TAIL | SLINKS-SHADOW | SLINKS-TOP |
|-----|-----|-----|-----|-----|-----|-----|-----|-----|
| CL-Recall | 0.825 | 0.923 | 0.973 | 0.973 | 0.715 | 0.989 | 0.614 | 0.816 |
| Accuracy | 0.682 | 0.801 | 0.891 | 0.922 | 0.626 | 0.963 | 0.365 | 0.615 |
| Accuracy (Disamb.) | 0.826 | 0.868 | 0.915 | 0.947 | 0.876 | 0.974 | 0.595 | 0.753 |


  0%|          | 0/8 [00:00<?, ?it/s]

|    |AIDA-B | TWEEKI | REDDIT-POSTS | REDDIT-COMM. | WNED-WIKI | SLINKS-TAIL | SLINKS-SHADOW | SLINKS-TOP |
|-----|-----|-----|-----|-----|-----|-----|-----|-----|
| CL-Recall | 0.824 | 0.893 | 0.964 | 0.964 | 0.878 | 0.917 | 0.532 | 0.705 |
| Accuracy | 0.73 | 0.762 | 0.866 | 0.859 | 0.766 | 0.915 | 0.417 | 0.626 |
| Accuracy (Disamb.) | 0.847 | 0.801 | 0.875 | 0.876 | 0.805 | 0.889 | 0.768 | 0.872 |


  0%|          | 0/8 [00:00<?, ?it/s]

|    |AIDA-B | TWEEKI | REDDIT-POSTS | REDDIT-COMM. | WNED-WIKI | SLINKS-TAIL | SLINKS-SHADOW | SLINKS-TOP |
|-----|-----|-----|-----|-----|-----|-----|-----|-----|
| CL-Recall | 0.835 | 0.916 | 0.983 | 0.973 | 0.737 | 0.993 | 0.632 | 0.812 |
| Accuracy | 0.659 | 0.763 | 0.818 | 0.848 | 0.577 | 0.947 | 0.273 | 0.562 |
| Accuracy (Disamb.) | 0.789 | 0.832 | 0.832 | 0.871 | 0.783 | 0.953 | 0.433 | 0.692 |
