In [1]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Load the tokenizer and the seq2seq model from the same Hugging Face checkpoint.
# The checkpoint files are downloaded on first use (see the progress output below).
tokenizer = AutoTokenizer.from_pretrained("apoorvumang/kgt5-wikikg90mv2")
model = AutoModelForSeq2SeqLM.from_pretrained("apoorvumang/kgt5-wikikg90mv2")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=242083771.0, style=ProgressStyle(descri…




In [28]:
import torch

def getScores(ids, scores, pad_token_id):
    """Compute a log-probability score for each sequence from model.generate output.

    Args:
        ids: 2-D tensor of generated token ids, one row per sequence. The first
            column is treated as the start token and dropped.
        scores: Sequence of 2-D per-step score tensors (one row per sequence,
            one column per vocabulary entry); they are stacked along a new
            step dimension.
        pad_token_id: Token id whose positions contribute 0 to the sum.

    Returns:
        1-D numpy array holding, for each sequence, the sum of the
        log-softmax scores of its non-pad tokens.
    """
    step_scores = torch.stack(scores, dim=1)
    log_probs = torch.log_softmax(step_scores, dim=2)
    # remove start token
    ids = ids[:, 1:]
    # Gather only the log-prob of each generated token. A single-column index
    # avoids materialising a second (batch, steps, vocab) tensor.
    token_log_probs = torch.gather(log_probs, 2, ids.unsqueeze(-1)).squeeze(-1)
    # Padded positions must not contribute; assigning 0 (rather than
    # multiplying by a mask) also neutralises -inf entries.
    pad_mask = (ids == pad_token_id)
    token_log_probs = token_log_probs.masked_fill(pad_mask, 0.0)
    sequence_scores = token_log_probs.sum(dim=-1)
    return sequence_scores.cpu().detach().numpy()

def topkSample(input, model, tokenizer,
                num_samples=5,
                num_beams=1,
                max_output_length=30):
    """Sample candidate outputs for ``input`` and rank them by sequence score.

    Args:
        input: Query text passed to ``tokenizer``.
        model: Model exposing ``generate`` (called with ``do_sample=True``).
        tokenizer: Tokenizer used to encode ``input`` and decode the samples.
        num_samples: Forwarded to ``generate`` as ``num_return_sequences``.
        num_beams: Forwarded to ``generate`` as ``num_beams``.
        max_output_length: Forwarded to ``generate`` as ``max_length``.

    Returns:
        List of ``(decoded_string, score)`` tuples sorted by descending score.
        The same string can appear more than once; results are not
        de-duplicated.
    """
    # NOTE(review): the scoring in getScores is only exercised in this notebook
    # with num_beams=1 -- confirm it is still valid before using num_beams > 1.
    tokenized = tokenizer(input, return_tensors="pt")
    # Move the encoded inputs to the model's device so generation also works
    # after the model has been moved to a GPU. Objects without a ``device``
    # attribute are left untouched.
    device = getattr(model, "device", None)
    if device is not None:
        tokenized = tokenized.to(device)
    out = model.generate(**tokenized,
                        do_sample=True,
                        num_return_sequences=num_samples,
                        num_beams=num_beams,
                        eos_token_id=tokenizer.eos_token_id,
                        pad_token_id=tokenizer.pad_token_id,
                        output_scores=True,
                        return_dict_in_generate=True,
                        max_length=max_output_length,)
    out_tokens = out.sequences
    out_str = tokenizer.batch_decode(out_tokens, skip_special_tokens=True)
    out_scores = getScores(out_tokens, out.scores, tokenizer.pad_token_id)

    # Pair each decoded string with its score and rank best-first.
    return sorted(zip(out_str, out_scores), key=lambda pair: pair[1], reverse=True)

def greedyPredict(input, model, tokenizer):
    """Return the decoded string for ``input`` using ``model.generate`` defaults.

    Args:
        input: Query text; it is wrapped in a one-element batch for the tokenizer.
        model: Model exposing ``generate``.
        tokenizer: Tokenizer used to encode ``input`` and decode the output.

    Returns:
        The first (and only) decoded output string, with special tokens skipped.
    """
    input_ids = tokenizer([input], return_tensors="pt").input_ids
    # Move the input ids to the model's device so generation also works after
    # the model has been moved to a GPU. Objects without a ``device``
    # attribute are left untouched.
    device = getattr(model, "device", None)
    if device is not None:
        input_ids = input_ids.to(device)
    out_tokens = model.generate(input_ids)
    out_str = tokenizer.batch_decode(out_tokens, skip_special_tokens=True)
    return out_str[0]

In [29]:
# An example from the validation set that the model predicts correctly.
# You can try your own examples here. What's your noble title?
# The query has the form "<entity>| <relation>"; the result is a list of
# (prediction, score) pairs sorted best-first.
# NOTE(review): the name `input` shadows the Python builtin of the same name.
input = "Sophie Valdemarsdottir| noble title"
out = topkSample(input, model, tokenizer, num_samples=5)
out

[('princess', -1.093592),
 ('princess', -1.0935932),
 ('duke', -2.0669463),
 ('lady', -3.0623672),
 ('lord', -5.4861264)]

You can further load the list of entity aliases, keep only those predictions that are valid entities,
and then create a reverse mapping from alias -> integer id to get the final predictions in the required format.

However, loading these aliases into memory as a dictionary requires a lot of RAM, and you also need to download the aliases file.

The submitted validation/test results were obtained by sampling 300 times for each input, then applying the above procedure, followed by filtering to known entities. The final MRR can vary slightly because of the randomness of sampling (we found that although beam search gives deterministic output, its results are inferior to sampling a large number of times).

In [20]:
# Download valid.txt. The same URL also works with test.txt; however, the test
# file does not contain the correct tails.
!wget https://storage.googleapis.com/kgt5-wikikg90mv2/valid.txt

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
--2022-02-15 03:13:37--  https://storage.googleapis.com/kgt5-wikikg90mv2/valid.txt
Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.179.80, 142.250.179.112, 216.58.204.112, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.179.80|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1047144 (1023K) [text/plain]
Saving to: ‘valid.txt’


2022-02-15 03:13:38 (2.04 MB/s) - ‘valid.txt’ saved [1047144/1047144]



In [25]:
fname = 'valid.txt'
# Read every line of the validation file with trailing whitespace (including
# the newline) stripped. The context manager closes the file even if reading
# fails. The explicit encoding avoids relying on the platform default
# (assumes the file is UTF-8 -- TODO confirm).
with open(fname, encoding='utf-8') as f:
    valid_lines = [line.rstrip() for line in f]

In [39]:
# Show the first validation line. Judging by the output below, each line is
# "<entity>| <relation>", then a tab, then the target string.
print(valid_lines[0])

untitled Spider-Man: Into the Spider-Verse sequel| director	Kemp Powers


In [36]:
from tqdm.auto import tqdm
# Try unfiltered hits@k. This is an approximation, since the model can sample
# the same string multiple times.
# You should run this on a GPU if you want to evaluate on all points with 300
# samples each.
k = 1
count_at_k = 0
max_predictions = k
max_points = 1000
# Slice once so the loop and the denominator below use the same set of lines.
eval_lines = valid_lines[:max_points]
for line in tqdm(eval_lines):
    # Each line is "<query>\t<target>".
    input, target = line.split('\t')
    model_output = topkSample(input, model, tokenizer, num_samples=max_predictions)
    prediction_strings = [x[0] for x in model_output]
    if target in prediction_strings:
        count_at_k += 1
# Divide by the number of lines actually evaluated: the file may hold fewer
# than max_points lines, in which case dividing by max_points under-reports.
num_points = len(eval_lines)
hits_at_k = count_at_k / num_points if num_points else 0.0
print('Hits at {0} unfiltered: {1}'.format(k, hits_at_k))

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Hits at 1 unfiltered: 0.135
