> This notebook is meant to be run on Kaggle.

# Load Model

In [None]:
%%capture
# Install the third-party packages this notebook relies on: transformers and
# auto-gptq are imported below, gdown is used later to download the test data.
# NOTE(review): each `!` line presumably runs in its own shell, so this export
# likely does not reach the pip commands that follow — confirm whether the
# variable is actually needed (a `%env` line would persist it).
!export GITHUB_ACTIONS=true
!pip install transformers
!pip install auto-gptq gdown

In [1]:
import warnings
# Suppress all Python warnings for the rest of the session to keep cell output
# short; remove this line when debugging library or deprecation issues.
warnings.filterwarnings("ignore")

In [2]:
from transformers import AutoTokenizer, logging
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

In [3]:
# Repository id shared by the tokenizer and the quantized checkpoint.
model_name_or_path = "TheBloke/Llama-2-13B-chat-GPTQ"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# Loader options gathered in one place so they are easy to review and tweak.
quantized_load_options = dict(
    revision="gptq-8bit-64g-actorder_True",
    use_safetensors=True,
    trust_remote_code=True,
    device_map="auto",
    use_triton=False,
    quantize_config=None,
)

model = AutoGPTQForCausalLM.from_quantized(
    model_name_or_path,
    **quantized_load_options,
)

  0%|          | 0/1523 [00:00<?, ?w/s]

# Load Candidate List

In [None]:
%%capture
!mkdir /kaggle/temp
!mkdir /kaggle/temp/zelda-test
!gdown -q -O /kaggle/temp/zelda-test.zip --fuzzy https://drive.google.com/file/d/1Qi19SfGoztNrEx8opQFd5kM1NIFdvxob/view?usp=sharing
!unzip /kaggle/temp/zelda-test.zip -d /kaggle/temp/zelda-test

In [4]:
import json
import pickle
import re
from pathlib import Path
from tqdm.auto import tqdm

In [5]:
# Matches runs of non-word characters; used to strip punctuation from mentions.
punctuation_remover = re.compile(r"[\W]+")

# Location of the pickled mention -> entities table extracted above.
CANDIDATE_TABLE_PATH = '/kaggle/temp/zelda-test/zelda_mention_entities_counter.pickle'

# Stays None if loading fails, so later lookups fail loudly.
mention_entities_counter = None

# NOTE: pickle.load can execute arbitrary code embedded in the file — only use
# this with an archive you trust.
with open(CANDIDATE_TABLE_PATH, 'rb') as handle:
    mention_entities_counter = pickle.load(handle)

# candidate retriever
def get_candidates(mention, limit=None):
    """Look up the candidate entity titles for a mention.

    The lookup in ``mention_entities_counter`` is attempted with the mention
    as given, then with spaces removed and lower-cased, and finally with all
    non-word characters stripped from that normalised form.

    Args:
        mention: surface form to look up.
        limit: optional maximum number of candidates to keep.

    Returns:
        A list of titles (spaces replaced by underscores) ordered by
        descending value in the table (presumably occurrence counts), or
        ``None`` when no lookup variant matches.
    """
    entity_counts = mention_entities_counter.get(mention)

    # Fallback 1: whitespace-free, lower-cased key.
    if entity_counts is None:
        mention = mention.replace(' ', '').lower()
        entity_counts = mention_entities_counter.get(mention)

    # Fallback 2: additionally drop punctuation from the normalised key.
    if entity_counts is None:
        mention = punctuation_remover.sub("", mention)
        entity_counts = mention_entities_counter.get(mention)

    if entity_counts is None:
        return

    ranked = sorted(entity_counts.items(), key=lambda pair: pair[1], reverse=True)

    needs_trim = limit is not None and len(ranked) > limit
    if needs_trim:
        ranked = ranked[:limit]

    return [pair[0].replace(" ", "_") for pair in ranked]

# Prefix-Constrained Beam Search

In [6]:
def restrict_phrases(initial_input_ids, candidates):
    """Build a ``prefix_allowed_tokens_fn`` that limits generation to ``candidates``.

    Args:
        initial_input_ids: prompt token ids; only ``shape[1]`` (the prompt
            length) is used, to separate the prompt from the generated answer.
        candidates: iterable of allowed answer strings.

    Returns:
        A callable ``(batch_id, input_ids) -> list of token ids`` listing the
        tokens that may be generated next for the given partial sequence.
    """
    prompt_length = initial_input_ids.shape[1]
    # Token id the original code strips from the front of an encoding.
    # NOTE(review): presumably the tokenizer's bare leading-space piece for
    # this model — confirm before switching to another tokenizer.
    leading_space_token_id = 29871

    def prefix_allowed_tokens(batch_id, input_ids):
        # Text generated so far, with spaces dropped before matching against
        # the candidate strings.
        decoded = tokenizer.decode(input_ids[prompt_length:]).replace(" ", "")

        next_tokens = set()
        for candidate in candidates:
            if not candidate.startswith(decoded):
                continue

            remainder = candidate[len(decoded):]
            if not remainder:
                # The candidate is complete: allow the sequence to end.
                next_tokens.add(tokenizer.eos_token_id)
                continue

            token_ids = tokenizer.encode(remainder, add_special_tokens=False)
            # Drop the space token only when it is the first token; the
            # previous `list.remove` call deleted it wherever it appeared.
            if token_ids and token_ids[0] == leading_space_token_id:
                token_ids = token_ids[1:]
            # Guard against an empty encoding instead of raising IndexError.
            if token_ids:
                next_tokens.add(token_ids[0])

        # With no possible continuation, the only allowed token is EOS.
        return list(next_tokens) if next_tokens else [tokenizer.eos_token_id]

    return prefix_allowed_tokens


def generate(input_ids, **kwargs):
    """Run ``model.generate`` and return the first returned sequence.

    Args:
        input_ids: prompt token ids; ``shape[1]`` is used as the prompt length.
        **kwargs: forwarded unchanged to ``model.generate``.

    Returns:
        ``(score, text)`` for the first entry of the generation output, where
        ``text`` is the decoded continuation after the prompt with special
        tokens skipped and spaces removed, or ``None`` when the output holds
        no sequences.
    """
    output = model.generate(
        inputs=input_ids,
        return_dict_in_generate=True,
        output_scores=True,
        **kwargs,
    )

    # Only the first (sequence, score) pair is ever used.
    first_pair = next(zip(output.sequences, output.sequences_scores), None)
    if first_pair is None:
        return None

    sequence, sequence_score = first_pair
    continuation = sequence[input_ids.shape[1]:]
    decoded = tokenizer.decode(continuation, skip_special_tokens=True).replace(" ", "")
    return (sequence_score.item(), decoded)

In [7]:
# Prompt template with two %s slots, filled as (context, mention) by
# disambiguate(). It ends with the Wikipedia URL prefix, so the text the model
# generates after it is compared directly against candidate article titles.
prompt = '''<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant.
<</SYS>>

In the following context

%s

What is the referent Wikipedia article of "%s" [/INST] https://en.wikipedia.org/wiki/'''

def disambiguate(context, mention, num_beams=5):
    """Pick the candidate title that best matches ``mention`` in ``context``.

    Args:
        context: text surrounding the mention, inserted into the prompt.
        mention: surface form to disambiguate.
        num_beams: beam width used for the constrained beam search.

    Returns:
        A ``(score, result, candidates)`` tuple:
        * ``(None, None, [])`` when the mention has no candidates,
        * ``(None, candidate, candidates)`` when there is exactly one,
        * otherwise the score and text returned by ``generate`` plus the
          candidate list.
    """
    candidates = get_candidates(mention)

    # get_candidates returns None (not []) for unknown mentions, so test for
    # "no candidates" with truthiness instead of comparing against [].
    if not candidates:
        return None, None, []

    # A single candidate needs no model call.
    if len(candidates) == 1:
        return None, candidates[0], candidates

    input_ids = tokenizer(prompt % (context, mention), return_tensors='pt').input_ids.cuda()

    # The longest candidate (in tokens) bounds the generation length; one
    # extra token leaves room for the end-of-sequence token.
    candidate_token_ids = tokenizer(candidates, add_special_tokens=False).input_ids
    max_new_tokens = max(len(ids) for ids in candidate_token_ids)

    score, result = generate(
        input_ids,
        num_beams=num_beams,
        do_sample=False,
        num_return_sequences=1,
        max_new_tokens=max_new_tokens + 1,
        remove_invalid_values=True,
        prefix_allowed_tokens_fn=restrict_phrases(input_ids, candidates),
    )

    return score, result, candidates

# Load Test Data

In [8]:
def get_mention_idx(sentences):
    """Return the index of the first sentence containing "<a>", else None."""
    matching_positions = (
        position
        for position, sentence in enumerate(sentences)
        if "<a>" in sentence
    )
    return next(matching_positions, None)

def get_neighboring_sentences(text, n = 1):
    """Return the tagged sentence plus up to ``n`` sentences on each side.

    The text is split after ".", "!" or "?" followed by whitespace; the
    sentence holding the "<a>" marker is located with ``get_mention_idx`` and
    the selected window is re-joined with single spaces.
    """
    sentence_list = re.split(r'(?<=[.!?])\s+', text)
    mention_position = get_mention_idx(sentence_list)

    # Clamp the left edge at zero; slicing already tolerates an overshooting right edge.
    window_start = max(mention_position - n, 0)
    window_stop = mention_position + n + 1

    window = sentence_list[window_start:window_stop]
    return " ".join(window)

def get_tagged_context(text, span):
    """Wrap the ``span`` (start, end) slice of ``text`` in "<a>" ... "</a>"."""
    begin, stop = span
    pieces = (text[:begin], "<a>", text[begin:stop], "</a>", text[stop:])
    return "".join(pieces)

In [9]:
# Collect every extracted file whose name ends in "jsonl", in directory order.
test_dir = Path("/kaggle/temp/zelda-test")
test_files = []
for path in test_dir.iterdir():
    if path.name.endswith("jsonl"):
        test_files.append(path)

# Show each file next to its position in the list.
for idx, f in enumerate(test_files):
    print(idx, f.name)

0 test_cweb.jsonl
1 test_shadowlinks-shadow.jsonl
2 test_shadowlinks-tail.jsonl
3 test_tweeki.jsonl
4 test_wned-wiki.jsonl
5 test_reddit-comments.jsonl
6 test_aida-b.jsonl
7 test_shadowlinks-top.jsonl
8 test_reddit-posts.jsonl


In [37]:
# Pick the evaluation file by name. Path.iterdir() yields entries in arbitrary
# order, so a positional index into test_files is not stable across sessions.
TEST_FILE_NAME = "test_aida-b.jsonl"

lines = []
file = {f.name: f for f in test_files}[TEST_FILE_NAME]  # KeyError names a missing file
print(file)

# Build one (context, mention, gold_title) tuple per annotated span.
with open(file, 'r', encoding="utf-8") as handle:
    for entry in handle:
        entry = json.loads(entry)
        for span, title in zip(entry["index"], entry["wikipedia_titles"]):
            start, end = span
            # Tag the mention, then keep two sentences on each side of it.
            context = get_tagged_context(entry["text"], span)
            context = get_neighboring_sentences(context, 2)
            lines.append((context, entry["text"][start:end], title))

total = len(lines)
print(total)
print(*lines[:5], sep="\n")

/kaggle/temp/zelda-test/test_aida-b.jsonl
4485
('SOCCER - <a>JAPAN</a> GET LUCKY WIN , CHINA IN SURPRISE DEFEAT . Nadim Ladki AL-AIN , United Arab Emirates 1996-12-06 Japan began the defence of their Asian Cup title with a lucky 2-1 win against Syria in a Group C championship match on Friday . But China saw their luck desert them in the second match of the group , crashing to a surprise 2-0 defeat to newcomers Uzbekistan .', 'JAPAN', 'Japan_national_football_team')
('SOCCER - JAPAN GET LUCKY WIN , <a>CHINA</a> IN SURPRISE DEFEAT . Nadim Ladki AL-AIN , United Arab Emirates 1996-12-06 Japan began the defence of their Asian Cup title with a lucky 2-1 win against Syria in a Group C championship match on Friday . But China saw their luck desert them in the second match of the group , crashing to a surprise 2-0 defeat to newcomers Uzbekistan .', 'CHINA', 'China_national_football_team')
('SOCCER - JAPAN GET LUCKY WIN , CHINA IN SURPRISE DEFEAT . Nadim Ladki <a>AL-AIN</a> , United Arab Emirate

In [38]:
outfile = "/kaggle/working/responses-" + file.name.split("_")[-1]

# Contexts that already have a stored response from an earlier run.
l2 = set()
try:
    with open(outfile, "r") as f:
        for line in tqdm(f):
            try:
                line = json.loads(line)
            except json.JSONDecodeError:
                # Skip lines that are not valid JSON (e.g. a partially
                # written record); those entries are simply processed again.
                continue
            l2.add(line["context"])
except FileNotFoundError:
    # No earlier responses file: nothing has been answered yet.
    pass

# Keep only the entries whose context has no stored response.
missing = [idx for idx, i in enumerate(lines) if i[0] not in l2]

lines = [lines[i] for i in missing]
print(len(lines), "entries left to process")

def shorten(context, l = 300):
    """Crop ``context`` to a window around the tagged mention.

    The window starts ``l`` characters before "<a>" (clamped to the start of
    the string) and ends ``l`` characters after the position where "</a>"
    begins. Raises ``ValueError`` if either marker is missing.
    """
    open_tag_at = context.index("<a>")
    close_tag_at = context.index("</a>")

    window_start = max(0, open_tag_at - l)
    window_stop = close_tag_at + l
    return context[window_start:window_stop]
    
# Crop the context of every remaining entry; mention and target are unchanged.
lines = [
    (shorten(full_context), surface_form, gold_title)
    for (full_context, surface_form, gold_title) in lines
]

0it [00:00, ?it/s]

In [39]:
# State for the inference loop below: the output path for this pass, the
# indices that raised an error, and the number of entries already handled.
outfile = f"/kaggle/working/responses-2-{file.name.split('_')[-1]}"
missing = []
start_idx = 0

In [None]:
# Disambiguate every remaining entry and append each answer to `outfile` as one
# JSON object per line. `start_idx` counts the entries already handled, so
# re-running this cell after an interruption resumes where it stopped.
with open(outfile, "a") as f:
    for idx, (context, mention, title) in tqdm(enumerate(lines), total=len(lines)):

        # Already handled in an earlier execution of this cell.
        if idx < start_idx:
            continue

        title = "https://en.wikipedia.org/wiki/" + title.replace(" ", "_")

        try:
            score, result, candidates = disambiguate(context, mention, num_beams=5)
            # Discard generations that are not exactly one of the candidates.
            if result not in candidates:
                result = None
        except Exception as e:
            # Remember the failing index and carry on with the next entry.
            missing.append(idx)
            print("Error getting", idx, e)
        else:
            if result is not None:
                result_url = "https://en.wikipedia.org/wiki/" + result
                entry = {
                    "context": context,
                    "title": title,
                    "result": result_url,
                    "candidates": candidates,
                }

                # Write the whole record in one call and flush immediately so
                # an interrupted session cannot lose finished results or leave
                # a half-written line behind in the buffer.
                f.write(json.dumps(entry) + "\n")
                f.flush()

                print("Index  :", idx)
                print("title  :", title)
                print("result :", result_url)
                print("N-Cands:", len(candidates))

        print("-" * 50)

        start_idx += 1