In [1]:
%%capture
!export GITHUB_ACTIONS=true
!pip install auto-gptq
!pip install duckduckgo_search gdown

# Load Model

In [None]:
from transformers import AutoTokenizer, pipeline, logging
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

# Flag forwarded to the quantized-model loader further down.
use_triton = False

# Repository id shared by the tokenizer and the quantized model weights.
model_name_or_path = "TheBloke/Llama-2-13B-chat-GPTQ"

# Request the fast tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=True,
)

In [None]:
# Load the already-quantized checkpoint from the selected repository revision.
# NOTE(review): the revision name suggests 8-bit weights with group size 128 -
# confirm against the model card.
# NOTE(review): trust_remote_code=True is passed through as in the original;
# confirm it is actually needed for this checkpoint.
model = AutoGPTQForCausalLM.from_quantized(
    model_name_or_path,
    revision="gptq-8bit-128g-actorder_False",
    quantize_config=None,
    use_triton=use_triton,
    use_safetensors=True,
    device_map="auto",
    trust_remote_code=True,
)

# Setup Data

In [4]:
%%capture
!mkdir /kaggle/temp
!mkdir /kaggle/temp/zelda-test
!gdown -q -O /kaggle/temp/zelda-test.zip --fuzzy https://drive.google.com/file/d/1Qi19SfGoztNrEx8opQFd5kM1NIFdvxob/view?usp=sharing
!unzip /kaggle/temp/zelda-test.zip -d /kaggle/temp/zelda-test

In [5]:
import re
import json
import pickle
from pathlib import Path
from urllib.parse import unquote

from duckduckgo_search import DDGS
from tqdm.auto import tqdm

In [6]:
def ddg_answers(mention, num_res=20, desc_limit=300):
    """Build a numbered candidate list from ``DDGS.answers`` hits for ``mention``.

    For every hit, the last path segment of its ``url`` is URL-unquoted and
    appended to the English Wikipedia prefix. Candidates whose resulting URL
    was already collected are skipped.

    Args:
        mention: Query string passed to ``DDGS.answers``.
        num_res: Stop once this many distinct URLs have been collected.
        desc_limit: Maximum number of characters kept from each hit's ``text``.

    Returns:
        Tuple ``(results, urls)``: ``results`` is a string with one
        ``"\\n<rank>. <url> - <text>"`` entry per candidate and ``urls`` is the
        list of candidate URLs in the same order.
    """
    wiki_prefix = "https://en.wikipedia.org/wiki/"
    collected = []
    entries = []

    with DDGS() as client:
        for hit in client.answers(mention):
            page = unquote(hit['url'].split('/')[-1])
            link = wiki_prefix + page
            if link not in collected:
                # Rank is the 1-based position among the distinct URLs so far.
                entries.append(f"\n{len(collected)+1}. {link} - {hit['text'][:desc_limit]}")
                collected.append(link)
                if len(collected) == num_res:
                    break

    return "".join(entries), collected


def ddg_search(term, num_res=10, desc_limit=300):
    """Run a text search for ``term`` restricted to ``en.wikipedia.org``.

    Args:
        term: Search term; ``" site:en.wikipedia.org"`` is appended to it.
        num_res: Stop once this many hits have been collected.
        desc_limit: Maximum number of characters kept from each hit's ``body``.

    Returns:
        Tuple ``(results, urls)``: ``results`` starts with ``"results:\\n"``
        followed by one ``"\\n<rank>. <href> - <body> ..."`` entry per hit, and
        ``urls`` lists the ``href`` of every hit in the same order.
    """
    query = f'{term} site:en.wikipedia.org'
    listing = ["results:\n"]
    links = []

    with DDGS() as client:
        hits = client.text(query, safesearch="Off")
        for rank, hit in enumerate(hits, start=1):
            listing.append(f"\n{rank}. {hit['href']} - {hit['body'][:desc_limit]} ...")
            links.append(hit["href"])
            if len(links) == num_res:
                break

    return "".join(listing), links

In [7]:
def get_mention_idx(sentences):
    """Return the index of the first sentence containing ``<a>``, or None.

    Args:
        sentences: Sequence of sentence strings.

    Returns:
        Zero-based index of the first sentence that contains the opening
        anchor tag, or ``None`` when no sentence contains it.
    """
    for idx, s in enumerate(sentences):
        if "<a>" in s:
            return idx
    return None


def get_neighboring_sentences(text, n = 1):
    """Return the tagged mention together with ``n`` sentences on each side.

    The text is split into sentences at whitespace that follows ``.``, ``!`` or
    ``?``. The selected sentences are joined with single spaces.

    Args:
        text: Text containing a mention wrapped in ``<a>...</a>``.
        n: Number of sentences to keep before and after the mention.

    Returns:
        The window around the mention. If ``text`` contains no ``<a>`` tag it
        is returned unchanged.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)  # Split the text into sentences
    first = get_mention_idx(sentences)
    if first is None:
        # No tagged mention: there is nothing to centre the window on.
        return text

    # A mention that itself contains sentence punctuation (e.g. "St. Louis") is
    # cut by the splitter, so the closing tag can live in a later sentence.
    # Anchor the right edge on that sentence so the mention is never truncated.
    last = next(
        (i for i in range(first, len(sentences)) if "</a>" in sentences[i]),
        first,
    )

    start, end = max(first - n, 0), last + n + 1
    return " ".join(sentences[start:end])

def get_tagged_context(text, span):
    """Wrap the characters ``text[start:end]`` in ``<a>...</a>``.

    Args:
        text: Source string.
        span: Two-item ``(start, end)`` pair of character offsets into ``text``.

    Returns:
        A copy of ``text`` with ``<a>`` inserted at ``start`` and ``</a>``
        inserted at ``end``.
    """
    begin, stop = span
    before, mention, after = text[:begin], text[begin:stop], text[stop:]
    return f"{before}<a>{mention}</a>{after}"

def propocess_data(file):
    """Read a JSON-lines file and build one evaluation example per mention.

    Each non-blank line is parsed as a JSON object from which the keys
    ``"text"``, ``"index"`` and ``"wikipedia_titles"`` are read. ``"index"`` and
    ``"wikipedia_titles"`` are iterated in parallel: every ``(start, end)`` span
    is tagged in the text, and the tagged text is reduced to the mention plus
    two neighbouring sentences on each side.

    Args:
        file: Path of the ``.jsonl`` file to read.

    Returns:
        List of ``(context, mention, title)`` tuples, where ``mention`` is
        ``text[start:end]`` and ``title`` is the matching ``wikipedia_titles``
        item.
    """
    lines = []
    # Decode explicitly as UTF-8 so the character offsets in "index" do not
    # depend on the platform's default encoding.
    with open(file, 'r', encoding='utf-8') as handle:
        for raw in handle:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not raw.strip():
                continue
            entry = json.loads(raw)
            text = entry["text"]
            for span, title in zip(entry["index"], entry["wikipedia_titles"]):
                start, end = span
                context = get_tagged_context(text, span)
                context = get_neighboring_sentences(context, 2)
                lines.append((context, text[start:end], title))

    return lines

In [8]:
# Gather every entry of the extracted test directory whose name ends in ".jsonl"
# (kept in the order the directory iterator yields them), then display the list.
test_files = list(
    filter(
        lambda path: path.name.endswith(".jsonl"),
        Path("/kaggle/temp/zelda-test").iterdir(),
    )
)
test_files

[PosixPath('/kaggle/temp/zelda-test/test_cweb.jsonl'),
 PosixPath('/kaggle/temp/zelda-test/test_shadowlinks-shadow.jsonl'),
 PosixPath('/kaggle/temp/zelda-test/test_shadowlinks-tail.jsonl'),
 PosixPath('/kaggle/temp/zelda-test/test_tweeki.jsonl'),
 PosixPath('/kaggle/temp/zelda-test/test_wned-wiki.jsonl'),
 PosixPath('/kaggle/temp/zelda-test/test_reddit-comments.jsonl'),
 PosixPath('/kaggle/temp/zelda-test/test_aida-b.jsonl'),
 PosixPath('/kaggle/temp/zelda-test/test_shadowlinks-top.jsonl'),
 PosixPath('/kaggle/temp/zelda-test/test_reddit-posts.jsonl')]

In [9]:
# Prompt sent to the model for one mention. It has two `%s` slots that
# `disambiguate` fills, in order, with the stripped context (mention wrapped in
# <a>...</a>) and the numbered candidate list.
# The text ends inside an opened <answer> element, directly after the Wikipedia
# URL prefix, so whatever the model generates is appended right after that prefix.
# NOTE(review): the template starts with a literal "<s>"; if the tokenizer also
# prepends its own beginning-of-sequence token the prompt would carry two - confirm.
prompt_template = '''<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant.
<</SYS>>

In the following context

<context>%s</context>

Determine which of the following is the referent Wikipedia article of the text inside the anchor tag:
%s

provide your answer as <answer>Wikipedia URL</answer>[/INST] <answer>https://en.wikipedia.org/wiki/'''


# putting it all together
def disambiguate(context, mention, temp=0.1, top_p=0.3, max_new_t=125, debug=False):
    """Pick the Wikipedia article a mention refers to.

    Candidates come from ``ddg_search(mention)``. They are inserted into
    ``prompt_template`` together with the context, the prompt is run through
    the model, and the answer is read from the decoded generation.

    Args:
        context: Text with the mention wrapped in ``<a>...</a>``; stripped
            before being placed in the prompt.
        mention: Surface form used as the search term.
        temp: Passed to ``model.generate`` as ``temperature``.
        top_p: Passed to ``model.generate`` as ``top_p``.
        max_new_t: Passed to ``model.generate`` as ``max_new_tokens``.
        debug: When true, print the full decoded text (prompt + generation).

    Returns:
        ``(answer, urls, output)`` where ``answer`` is the content of the last
        ``<answer>`` element of the decoded text, ``urls`` is the candidate URL
        list and ``output`` is the decoded text after the last ``[/INST]``.
        When the search returns no URLs the function returns the two-item
        tuple ``(None, None)``.
    """
    def extract_answer(text):
        # Content of the last <answer> element in the decoded text.
        open_tag, close_tag = "<answer>", "</answer>"
        start = text.rfind(open_tag)
        if start == -1:
            return ""
        start += len(open_tag)

        # Look for the closing tag only AFTER the opening one. Searching the
        # whole text would match the "</answer>" of the instructions whenever
        # the generation never closes its own tag, yielding an empty slice.
        end = text.find(close_tag, start)
        if end != -1:
            return text[start:end]

        # Generation ended (or was cut off by max_new_tokens) before the closing
        # tag: keep the first whitespace-delimited token, ignoring a trailing
        # "</s>" marker (presumably the tokenizer's end-of-sequence string).
        remainder = text[start:].replace("</s>", " ").split()
        return remainder[0] if remainder else ""

    # generate candidates
    candidates, urls = ddg_search(mention)
    if urls == []:
        # NOTE(review): this path returns two values while the normal path
        # returns three, so `a, b, c = disambiguate(...)` raises ValueError here.
        # Left as-is on purpose: switching to three values must be done together
        # with making every caller handle a None answer.
        return None, None

    # generate answer
    prompt = prompt_template % (context.strip(), candidates)
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.cuda()
    generated = model.generate(inputs=input_ids, temperature=temp, top_p=top_p, max_new_tokens=max_new_t)
    output = tokenizer.decode(generated[0])
    if debug:
        print(output)

    answer = extract_answer(output)
    # Keep only what follows the instruction block for the returned output.
    output = output.split("[/INST]")[-1].strip()
    return answer, urls, output

In [10]:
# Select the split to evaluate by file name rather than by position:
# Path.iterdir() yields entries in arbitrary order, so a positional index can
# point at a different file on another run. In the recorded listing, position 6
# was test_aida-b.jsonl.
file = next(path for path in test_files if path.name == "test_aida-b.jsonl")
lines = propocess_data(file)

# Bookkeeping for the evaluation loop: index of the next example to process
# (lets the loop cell be re-run after an interruption) and indices that failed.
start_idx = 0
missing = []
# Output name is the part of the file name after the last underscore.
outfile = "/kaggle/working/" + file.name.split("_")[-1]

In [11]:
# idx = 7
# context, mention, target = lines[idx]
# answer, urls, output = disambiguate(context, mention)
# target, answer, output

In [None]:
# Run the disambiguation over every example and store one JSON object per line.
# When resuming (start_idx > 0) append to the existing file; opening with "w"
# would erase the results written before the interruption.
with open(outfile, "a" if start_idx > 0 else "w") as f:
    for idx, (context, mention, title) in tqdm(enumerate(lines), total=len(lines)):

        # Already handled by an earlier execution of this cell.
        if idx < start_idx:
            continue

        title = "https://en.wikipedia.org/wiki/" + title.replace(" ", "_")

        try:
            result, candidates, output = disambiguate(context, mention)
        except Exception as e:
            # Best effort: remember the failing example and keep going.
            missing.append(idx)
            print("Error getting", idx, e)
        else:
            if result is None:
                # No answer for this example: record it as missing instead of
                # writing a stale (or undefined) entry from a previous iteration.
                missing.append(idx)
                print("No result for", idx)
            else:
                entry = {
                    "context": context,
                    "target": title,
                    "result": result,
                    "output": output,
                    "candidates": candidates,
                }

                json.dump(entry, f)
                f.write("\n")
                # Push each record to disk so an interrupted run keeps its results.
                f.flush()
                print("Index  :", idx)
                print("Target :", title)
                print("Result :", result)
                print("Output :", output)

        print("-" * 50)
        # Advance the resume marker only after the example is fully handled.
        start_idx += 1