In [1]:
!pip install gensim



In [2]:
from __future__ import absolute_import, division, print_function

import argparse
import glob
import logging
import os
import pickle
import random
import re
import shutil

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    # Only a missing module should trigger the fallback; a bare `except:` would
    # also hide unrelated failures (including KeyboardInterrupt).
    from tensorboardX import SummaryWriter

from tqdm import trange
from tqdm.autonotebook import tqdm

from data_loader.hybrid_data_loaders import *
from data_loader.header_data_loaders import *
from data_loader.CT_Wiki_data_loaders import *
from data_loader.RE_data_loaders import *
from data_loader.EL_data_loaders import *
from model.configuration import TableConfig
from model.model import HybridTableMaskedLM, HybridTableCER, TableHeaderRanking, HybridTableCT,HybridTableEL,HybridTableRE,BertRE
from model.transformers import BertConfig,BertTokenizer, WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup
from utils.util import *
from baselines.row_population.metric import average_precision,ndcg_at_k
from baselines.cell_filling.cell_filling import *
from model import metric

In [3]:
# Logger named after the current module.
logger = logging.getLogger(__name__)

# Task key -> (config class, model class, tokenizer class).
# Later cells unpack an entry, e.g. MODEL_CLASSES['CF'] and MODEL_CLASSES['EL'],
# to build the config and model for that task.
MODEL_CLASSES = {
    'CER': (TableConfig, HybridTableCER, BertTokenizer),
    'CF' : (TableConfig, HybridTableMaskedLM, BertTokenizer),
    'HR': (TableConfig, TableHeaderRanking, BertTokenizer),
    'CT': (TableConfig, HybridTableCT, BertTokenizer),
    'EL': (TableConfig, HybridTableEL, BertTokenizer),
    'RE': (TableConfig, HybridTableRE, BertTokenizer),
    'REBERT': (BertConfig, BertRE, BertTokenizer)
}

In [4]:
# set data directory, this will be used to load test data
# (the vocab files, the test JSON files and the result pickle below are all
# resolved relative to this directory)
data_dir = 'data/'

In [5]:
# Model configuration file passed to `from_pretrained` by every task below.
config_name = "configs/table-base-config_v2.json"
# Models and input tensors are moved to this device; as written a CUDA GPU is required.
device = torch.device('cuda')
# load entity vocab from entity_vocab.txt
entity_vocab = load_entity_vocab(data_dir, ignore_bad_title=True, min_ent_count=2)
# Reverse lookup: each entry's 'wiki_id' -> its key in entity_vocab.
entity_wikid2id = {entity_vocab[x]['wiki_id']:x for x in entity_vocab}

total number of entity: 926135
remove because of empty title: 14206
remove because count<2: 847401


In [6]:
# Enumerate GPUs in PCI bus order and expose only device 0 to this process.
# NOTE(review): these variables are set after torch was imported and after
# `device` was created; presumably they still take effect because no CUDA work
# has run yet — confirm, or move this cell above the first torch usage.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Table of Contents
This notebook shows examples of how to use the model components and how to run the evaluation for different tasks.
* [Pretrained and Cell Filling](#cf)
* [Entity Linking](#el)
* [Column Type Classification](#ct)
* [Relation Extraction](#re)

<a class="anchor" id="cf"></a>
# Pretrained and CF
Here we show how to use the pretrained model to get contextualized representation for a given input table. 

We use the cell filling task for demonstration as it does not need task-specific finetuning.

In [None]:
config_class, model_class, _ = MODEL_CLASSES['CF']
config = config_class.from_pretrained(config_name)
# Ask the model to return attention maps as well; a later cell prints their shapes.
config.output_attentions = True

# For CF, we use the base HybridTableMaskedLM, and directly load the pretrained checkpoint
# The directory gets its own name so `checkpoint` only ever holds the state dict.
checkpoint_dir = "checkpoint/pretrained/"
model = model_class(config, is_simple=True)
# map_location='cpu' makes loading independent of the GPU the weights were saved
# from; model.to(device) below moves them to the target device.
checkpoint = torch.load(os.path.join(checkpoint_dir, 'pytorch_model.bin'), map_location='cpu')
model.load_state_dict(checkpoint)
model.to(device)
# Inference only: switch off dropout and other train-time behaviour.
model.eval()
# load the module for cell filling baselines
CF = cell_filling(data_dir)

In [None]:
# Load the cell filling test set and keep only the first 10 tables for this demo.
with open(os.path.join(data_dir,"CF_test_data.json"), 'r') as f:
    dev_data = json.load(f)[:10]
print('example for cell filling')
display(dev_data[0])
# the dataset here is the dataloader for pretraining. We use it to pass the config to construct the cell filling example
dataset = WikiHybridTableDataset(data_dir,entity_vocab,max_cell=100, max_input_tok=350, max_input_ent=150, src="dev", max_length = [50, 10, 10], force_new=False, tokenizer = None, mode=0)
print('example of pretraining data')
# Read only the first line of the JSONL file as a sample record.
# NOTE(review): if the file is empty, `example` is never assigned and the
# display call below raises NameError.
with open(os.path.join(data_dir, 'dev_tables.jsonl'), 'r') as f:
    for line in f:
        example = json.loads(line.strip())
        break
display(example)

In [None]:
# This is an example of converting an arbitrary table into model input.
# Here we show an example for the cell filling task.
# The input entities are the entities in the subject column; we append one [ENT_MASK] per entity and use its representation to match against the candidate entities.
def CF_build_input(pgEnt, pgTitle, secTitle, caption, headers, core_entities, core_entities_text, entity_cand, config):
    """Build the model inputs for one cell filling query.

    Token sequence: metadata tokens (page title, section title and, when it
    differs from the section title, the caption) followed by the tokens of
    every header.

    Entity sequence: ``[page entity] + core entities + one [ENT_MASK] per core
    entity``. The masks block core entities from the second header and the
    [ENT_MASK] slots from the first header, and link the i-th core entity to
    the i-th [ENT_MASK] slot through identity matrices.

    Args:
        pgEnt: key of the page entity in ``config.entity_wikid2id``; ``-1``
            maps to id 0.
        pgTitle, secTitle, caption: metadata strings.
        headers: header strings; at least two are required because the masks
            index the first two header spans.
        core_entities: keys (in ``config.entity_wikid2id``) of the given entities.
        core_entities_text: surface text of each core entity; an empty string
            yields an empty token list.
        entity_cand: keys of the candidate entities to score.
        config: object providing ``tokenizer.encode``, ``entity_wikid2id`` and
            the ``max_title_length`` / ``max_header_length`` /
            ``max_cell_length`` limits.

    Returns:
        An 11-tuple of LongTensors, each with a leading batch dimension of 1:
        token ids, token types, token positions, token attention mask, entity
        ids, padded entity text, entity text lengths, entity types, entity
        mask types, entity attention mask, candidate entity ids.
    """
    tokenizer = config.tokenizer
    ent2id = config.entity_wikid2id
    n_core = len(core_entities)

    # ---- token side: metadata (type 0) followed by headers (type 1) ----
    title_ids = tokenizer.encode(pgTitle, max_length=config.max_title_length, add_special_tokens=False)
    meta_ids = title_ids + tokenizer.encode(secTitle, max_length=config.max_title_length, add_special_tokens=False)
    if caption != secTitle:
        meta_ids += tokenizer.encode(caption, max_length=config.max_title_length, add_special_tokens=False)
    header_ids = [
        tokenizer.encode(header, max_length=config.max_header_length, add_special_tokens=False)
        for header in headers
    ]

    tok_ids = list(meta_ids)
    tok_pos = list(range(len(meta_ids)))
    tok_type = [0] * len(meta_ids)
    header_span = []
    for ids in header_ids:
        span_start = len(tok_ids)
        header_span.append([span_start, span_start + len(ids)])
        tok_ids.extend(ids)
        # positions restart from 0 inside every header
        tok_pos.extend(range(len(ids)))
        tok_type.extend([1] * len(ids))

    # ---- entity side: page entity (type 2), core entities (3), masks (4) ----
    ent_ids = [ent2id[pgEnt] if pgEnt != -1 else 0]
    ent_text = [title_ids[:config.max_cell_length]]
    ent_type = [2]

    ent_ids.extend([ent2id[entity] for entity in core_entities])
    ent_text.extend([
        tokenizer.encode(text, max_length=config.max_cell_length, add_special_tokens=False) if len(text) != 0 else []
        for text in core_entities_text
    ])
    ent_type.extend([3] * n_core)

    mask_id = ent2id['[ENT_MASK]']
    ent_ids.extend([mask_id] * n_core)
    ent_text.extend([[] for _ in range(n_core)])
    ent_type.extend([4] * n_core)

    # An empty cell is reported with length 1 and stays all-zero in the padded matrix.
    ent_cell_length = [max(len(cell), 1) for cell in ent_text]
    ent_text_padded = np.zeros([len(ent_text), max(ent_cell_length)], dtype=int)
    for row, cell in enumerate(ent_text):
        ent_text_padded[row, :len(cell)] = cell
    assert len(ent_ids) == 1+2*len(core_entities)

    # ---- attention masks over the concatenated [tokens | entities] axis ----
    n_tok = len(tok_ids)
    n_ent = len(ent_ids)
    core_lo = n_tok + 1          # first core-entity column
    mask_lo = core_lo + n_core   # first [ENT_MASK] column (also end of core columns)
    first_header, second_header = header_span[0], header_span[1]
    h0_lo, h0_hi = first_header
    h1_lo, h1_hi = second_header
    link = np.eye(n_core, dtype=int)

    tok_mask = np.ones([1, n_tok, n_tok + n_ent], dtype=int)
    tok_mask[0, h0_lo:h0_hi, mask_lo:] = 0
    tok_mask[0, h1_lo:h1_hi, core_lo:mask_lo] = 0
    tok_mask[0, :, mask_lo:] = 0

    # build the mask for entities
    ent_mask = np.ones([1, n_ent, n_tok + n_ent], dtype=int)
    core_rows = slice(1, 1 + n_core)
    mask_rows = slice(1 + n_core, None)
    ent_mask[0, core_rows, h1_lo:h1_hi] = 0
    ent_mask[0, core_rows, mask_lo:] = link
    ent_mask[0, mask_rows, h0_lo:h0_hi] = 0
    ent_mask[0, mask_rows, core_lo:mask_lo] = link
    ent_mask[0, mask_rows, mask_lo:] = link

    # ---- tensor conversion (batch dimension of 1) ----
    input_tok_mask = torch.LongTensor(tok_mask)
    input_ent_mask = torch.LongTensor(ent_mask)

    input_tok = torch.LongTensor([tok_ids])
    input_tok_type = torch.LongTensor([tok_type])
    input_tok_pos = torch.LongTensor([tok_pos])

    input_ent = torch.LongTensor([ent_ids])
    input_ent_text = torch.LongTensor([ent_text_padded])
    input_ent_cell_length = torch.LongTensor([ent_cell_length])
    input_ent_type = torch.LongTensor([ent_type])

    # Only the appended [ENT_MASK] slots carry a non-zero mask type.
    input_ent_mask_type = torch.zeros_like(input_ent)
    input_ent_mask_type[:, 1 + n_core:] = mask_id

    candidate_entity_set = torch.LongTensor([[ent2id[entity] for entity in entity_cand]])

    return (input_tok, input_tok_type, input_tok_pos, input_tok_mask,
            input_ent, input_ent_text, input_ent_cell_length, input_ent_type,
            input_ent_mask_type, input_ent_mask, candidate_entity_set)

In [None]:
# Run cell filling on every loaded table and collect, per query row, the target
# entity, its candidates, the model ranking and the H2H baseline ranking.
results = []
for table_id,pgEnt,pgTitle,secTitle,caption,(h1, h2),data_sample in tqdm(dev_data):
    result = []
    # Process the rows of one table in chunks of at most 100 (see the slice at the end of the loop).
    while len(data_sample)!=0:
        core_entities = []
        core_entities_text = []
        target_entities = []
        all_entity_cand = set()
        entity_cand = []
        for (core_e, core_e_text), target_e in data_sample[:100]:
            assert target_e in entity_wikid2id
            core_entities.append(core_e)
            core_entities_text.append(core_e_text)
            target_entities.append(target_e)
            # Candidates for this row come from the baseline module; keep only those in the entity vocab.
            cands = CF.get_cand_row(core_e, h2)
            cands = {key:value for key,value in cands.items() if key in entity_wikid2id}
            entity_cand.append(cands)
            all_entity_cand |= set(cands.keys()) 
        # Freeze the union of candidates into a list: its order defines the score columns used below.
        all_entity_cand = list(all_entity_cand)
        input_tok, input_tok_type, input_tok_pos, input_tok_mask,\
            input_ent, input_ent_text, input_ent_text_length, input_ent_type, input_ent_mask_type, input_ent_mask, \
            candidate_entity_set = CF_build_input(pgEnt, pgTitle, secTitle, caption, [h1, h2], core_entities, core_entities_text, all_entity_cand, dataset)
        input_tok = input_tok.to(device)
        input_tok_type = input_tok_type.to(device)
        input_tok_pos = input_tok_pos.to(device)
        input_tok_mask = input_tok_mask.to(device)
        input_ent_text = input_ent_text.to(device)
        input_ent_text_length = input_ent_text_length.to(device)
        input_ent = input_ent.to(device)
        input_ent_type = input_ent_type.to(device)
        input_ent_mask_type = input_ent_mask_type.to(device)
        input_ent_mask = input_ent_mask.to(device)
        candidate_entity_set = candidate_entity_set.to(device)
        with torch.no_grad():
            tok_outputs, ent_outputs = model(input_tok, input_tok_type, input_tok_pos, input_tok_mask,
                            input_ent_text, input_ent_text_length, input_ent_mask_type,
                            input_ent, input_ent_type, input_ent_mask, candidate_entity_set)
            num_sample = len(target_entities)
            # Skip the page entity (index 0) and the core entities (1..num_sample);
            # the remaining rows belong to the appended [ENT_MASK] slots.
            ent_prediction_scores = ent_outputs[0][0,num_sample+1:].tolist()
        for i, target_e in enumerate(target_entities):
            predictions = ent_prediction_scores[i]
            if len(entity_cand[i]) == 0:
                # Nothing to rank for this row: record empty rankings.
                result.append([target_e, entity_cand[i], [], []])
            else:
                tmp_cand_scores = []
                # Restrict the shared candidate scores to this row's own candidates.
                for j, cand_e in enumerate(all_entity_cand):
                    if cand_e in entity_cand[i]:
                        tmp_cand_scores.append([cand_e, predictions[j]])
                sorted_cand_scores =  sorted(tmp_cand_scores, key=lambda z:z[1], reverse=True)
                sorted_cands = [z[0] for z in sorted_cand_scores]
                # use H2H as baseline
                base_sorted_cands = CF.rank_cand_h2h(h2, entity_cand[i])
                result.append([target_e, entity_cand[i], sorted_cands, base_sorted_cands])
        data_sample = data_sample[100:]
    results.append({
        'pgTitle': pgTitle,
        'secTitle': secTitle,
        'caption': caption,
        'headers': [h1, h2],
        'result': result
    })

In [None]:
# Inspect the model outputs left over from the last forward pass of the loop above.
# The attention entries are available because config.output_attentions was set to True.
print('tok(metadata) outputs', len(tok_outputs))
print('tok prediction logits: [batch_size, num_toks, vocab_size]\n', tok_outputs[0].shape)
print('tok hidden states: [batch_size, num_toks, hidden_size]\n', tok_outputs[1].shape)
print('tok attention: n_layers*[batch_size, num_attention_headers, num_toks, num_toks+num_ents]\n', tok_outputs[2][0].shape)
print('entity(cell) outputs', len(ent_outputs))
print('ent prediction logits: [batch_size, num_ents, candidate_size]\n', ent_outputs[0].shape)
print('ent hidden states: [batch_size, num_ents, hidden_size]\n', ent_outputs[1].shape)
print('ent attention: n_layers*[batch_size, num_attention_headers, num_ents, num_toks+num_ents]\n', ent_outputs[2][0].shape)

In [None]:
def get_precision(result):
    """Compute candidate recall and precision@{1,3,5,10} for one table.

    Args:
        result: iterable of ``[target_e, cand, p_neural, p_base]`` rows, where
            ``cand`` is the candidate collection for the row and ``p_neural`` /
            ``p_base`` are the candidates ranked by the model and by the
            baseline. Either ranking may be empty.

    Returns:
        ``(recall, precision_neural, precision_base)``. ``recall`` is the
        fraction of rows whose target is among the candidates. Each precision
        list holds the hit rates at cutoffs 1, 3, 5 and 10, computed only over
        the rows counted in ``recall``. When no row has its target among the
        candidates, ``(0, [0, 0, 0, 0], [0, 0, 0, 0])`` is returned.
    """
    cutoffs = (1, 3, 5, 10)
    recall = 0
    precision_neural = [0] * len(cutoffs)
    precision_base = [0] * len(cutoffs)
    for target_e, cand, p_neural, p_base in result:
        if target_e not in cand:
            continue
        recall += 1
        for slot, k in enumerate(cutoffs):
            # Slicing instead of indexing ([:1] rather than [0]) keeps an empty
            # ranking from raising IndexError; it simply counts as a miss.
            if target_e in p_neural[:k]:
                precision_neural[slot] += 1
            if target_e in p_base[:k]:
                precision_base[slot] += 1
    if recall == 0:
        return 0, [0 for _ in precision_neural], [0 for _ in precision_base]
    return recall/len(result), [z/recall for z in precision_neural], [z/recall for z in precision_base]

In [None]:
# Per-table (recall, neural precisions, baseline precisions), then macro-averages.
final_results = [get_precision(x['result']) for x in results]
print('recall', np.mean([x[0] for x in final_results]))
# Precision is averaged only over tables where at least one target was among the candidates.
tables_with_hits = [x for x in final_results if x[0]!=0]
for system_name, column in (('neural', 1), ('base', 2)):
    print(system_name)
    for slot, k in enumerate((1, 3, 5, 10)):
        print('p@%d' % k, np.mean([x[column][slot] for x in tables_with_hits]))

In [None]:
# NOTE(review): this cell recomputes and prints exactly the same metrics as the
# preceding cell; it looks like a leftover duplicate and can probably be removed.
final_results = [get_precision(x['result']) for x in results]
print('recall', np.mean([x[0] for x in final_results]))
print('neural')
# Precision is averaged only over tables with non-zero recall.
print('p@1', np.mean([x[1][0] for x in final_results if x[0]!=0]))
print('p@3', np.mean([x[1][1] for x in final_results if x[0]!=0]))
print('p@5', np.mean([x[1][2] for x in final_results if x[0]!=0]))
print('p@10', np.mean([x[1][3] for x in final_results if x[0]!=0]))
print('base')
print('p@1', np.mean([x[2][0] for x in final_results if x[0]!=0]))
print('p@3', np.mean([x[2][1] for x in final_results if x[0]!=0]))
print('p@5', np.mean([x[2][2] for x in final_results if x[0]!=0]))
print('p@10', np.mean([x[2][3] for x in final_results if x[0]!=0]))

In [None]:
# Same macro-averaged precision as above, restricted to tables whose second
# header does not contain the substring 'team'.
non_team_tables = [
    x for i, x in enumerate(final_results)
    if x[0]!=0 and 'team' not in results[i]['headers'][1]
]
for system_name, column in (('neural', 1), ('base', 2)):
    print(system_name)
    for slot, k in enumerate((1, 3, 5, 10)):
        print('p@%d' % k, np.mean([x[column][slot] for x in non_team_tables]))

In [None]:
# NOTE(review): this cell prints exactly the same filtered metrics as the
# preceding cell; it looks like a leftover duplicate and can probably be removed.
# Filter: tables with non-zero recall whose second header does not contain 'team'.
print('neural')
print('p@1', np.mean([x[1][0] for i,x in enumerate(final_results) if x[0]!=0 and 'team' not in results[i]['headers'][1]]))
print('p@3', np.mean([x[1][1] for i,x in enumerate(final_results) if x[0]!=0 and 'team' not in results[i]['headers'][1]]))
print('p@5', np.mean([x[1][2] for i,x in enumerate(final_results) if x[0]!=0 and 'team' not in results[i]['headers'][1]]))
print('p@10', np.mean([x[1][3] for i,x in enumerate(final_results) if x[0]!=0 and 'team' not in results[i]['headers'][1]]))
print('base')
print('p@1', np.mean([x[2][0] for i,x in enumerate(final_results) if x[0]!=0 and 'team' not in results[i]['headers'][1]]))
print('p@3', np.mean([x[2][1] for i,x in enumerate(final_results) if x[0]!=0 and 'team' not in results[i]['headers'][1]]))
print('p@5', np.mean([x[2][2] for i,x in enumerate(final_results) if x[0]!=0 and 'team' not in results[i]['headers'][1]]))
print('p@10', np.mean([x[2][3] for i,x in enumerate(final_results) if x[0]!=0 and 'team' not in results[i]['headers'][1]]))

<a class="anchor" id="el"></a>
# EL
Evaluate Entity Linking

In [7]:
#load dbpedia types from depedia_type_vocab.txt (now just type vocab)
type_vocab = load_type_vocab(data_dir)
config_class, model_class, _ = MODEL_CLASSES['EL']
config = config_class.from_pretrained(config_name)
# Size the entity-type vocabulary in the config from the loaded type vocab.
config.ent_type_vocab_size = len(type_vocab)
# The mode selects both the checkpoint sub-directory loaded below and the
# candidate features dropped in the evaluation loop: 1 drops the candidate
# description, 2 drops the candidate type, 0 keeps both.
config.mode=2

In [8]:
# Show the first record of the entity linking test file as a format example.
with open(os.path.join(data_dir, 'test_own.table_entity_linking.json'), 'r') as f:
    example = json.load(f)[0]
display(example)

['377509-5',
 '2000 Rugby League World Cup',
 'Final',
 '',
 ['australia', 'position', 'new zealand'],
 [[[9, 0], 'Robbie Kearns'],
  [[12, 0], 'Scott Hill'],
  [[14, 0], 'Nathan Hindmarsh'],
  [[7, 2], 'Craig Smith'],
  [[11, 2], 'Stephen Kearney'],
  [[1, 0], 'Mat Rogers'],
  [[13, 0], 'Trent Barrett'],
  [[10, 0], 'Gorden Tallis'],
  [[5, 2], 'Henry Paul'],
  [[16, 0], 'Jason Stevens'],
  [[4, 0], 'Wendell Sailor'],
  [[17, 0], 'Chris Anderson'],
  [[15, 2], 'Nathan Cayless'],
  [[8, 2], 'Richard Swain'],
  [[9, 2], 'Quentin Pongia'],
  [[6, 2], 'Stacey Jones'],
  [[2, 2], 'Tonie Carroll'],
  [[11, 0], 'Bryan Fletcher'],
  [[0, 0], 'Darren Lockyer'],
  [[4, 2], 'Lesley Vainikolo'],
  [[0, 2], 'Richie Barnett'],
  [[15, 0], 'Darren Britt'],
  [[8, 0], 'Andrew Johns'],
  [[14, 2], 'Joe Vagana'],
  [[12, 2], 'Ruben Wiki'],
  [[7, 0], 'Shane Webcke'],
  [[5, 0], 'Brad Fittler'],
  [[13, 2], 'Robbie Paul'],
  [[10, 2], 'Matt Rua'],
  [[2, 0], 'Adam MacDougall'],
  [[16, 2], 'Logan Swann'

In [9]:
# load test data from [dataset].table_entity_linking.json
# With force_new=False the recorded output shows it first trying a preprocessed
# pickle (data/procressed_EL/test_own.pickle).
test_dataset = ELDataset(data_dir, type_vocab, max_input_tok=500, src="test_own", max_length = [50, 10, 10, 100], force_new=False, tokenizer = None)

try loading preprocessed data from data/procressed_EL/test_own.pickle


In [10]:
model = model_class(config, is_simple=True)
# load the checkpoint based on mode
# map_location='cpu' makes loading independent of the GPU the weights were saved
# from; model.to(device) below moves them to the target device.
checkpoint = torch.load(f"checkpoint/entity_linking/{config.mode}/pytorch_model.bin", map_location='cpu')
#checkpoint = torch.load(f"checkpoint/pretrained/pytorch_model.bin")
# NOTE(review): the recorded output of this cell is a RuntimeError from
# load_state_dict (size mismatch for table.embeddings.ent_embeddings.weight and
# cand_embeddings.ent_type_embeddings.weight). The vocabulary sizes in `config`
# apparently do not match the ones this checkpoint was trained with — confirm
# the config file and type vocab before relying on this cell.
model.load_state_dict(checkpoint)
model.to(device)
# Inference only: switch off dropout and other train-time behaviour.
model.eval()

RuntimeError: Error(s) in loading state_dict for HybridTableEL:
	size mismatch for table.embeddings.ent_embeddings.weight: copying a param with shape torch.Size([926135, 312]) from checkpoint, the shape in current model is torch.Size([404, 312]).
	size mismatch for cand_embeddings.ent_type_embeddings.weight: copying a param with shape torch.Size([404, 312]) from checkpoint, the shape in current model is torch.Size([255, 312]).

In [None]:
test_batch_size = 10
test_sampler = SequentialSampler(test_dataset)
test_dataloader = ELLoader(test_dataset, sampler=test_sampler, batch_size=test_batch_size, is_train=False)

# Validate the mode once, before any batch is loaded, and say what is wrong.
# (Previously a message-less `Exception` was raised from inside the loop.)
if config.mode not in (0, 1, 2):
    raise ValueError("unsupported config.mode %r; expected 0, 1 or 2" % (config.mode,))

# Eval!
print("Num examples = %d"%len(test_dataset))
print("Batch size = %d"%test_batch_size)
test_loss = 0.0
test_acc = 0.0
nb_test_steps = 0
# One entry per table: [table id, entity indices, ranked candidate indices, ranked scores].
test_results = []

for batch in tqdm(test_dataloader, desc="Evaluating"):
    table_id, input_tok, input_tok_type, input_tok_pos, input_tok_mask, \
            input_ent_text, input_ent_text_length, input_ent_type, input_ent_mask, \
            cand_name, cand_name_length,cand_description, cand_description_length,cand_type, cand_type_length, cand_mask, \
            labels,entities_index = batch
    input_tok = input_tok.to(device)
    input_tok_type = input_tok_type.to(device)
    input_tok_pos = input_tok_pos.to(device)
    input_tok_mask = input_tok_mask.to(device)
    input_ent_text = input_ent_text.to(device)
    input_ent_text_length = input_ent_text_length.to(device)
    input_ent_type = input_ent_type.to(device)
    input_ent_mask = input_ent_mask.to(device)
    cand_name = cand_name.to(device)
    cand_name_length = cand_name_length.to(device)
    cand_description = cand_description.to(device)
    cand_description_length = cand_description_length.to(device)
    cand_type = cand_type.to(device)
    cand_type_length = cand_type_length.to(device)
    cand_mask = cand_mask.to(device)
    labels = labels.to(device)

    # Drop the candidate feature that the selected mode does not use
    # (mode 0 keeps both description and type).
    if config.mode == 1:
        cand_description = None
        cand_description_length = None
    elif config.mode == 2:
        cand_type = None
        cand_type_length = None

    with torch.no_grad():
        outputs = model(input_tok, input_tok_type, input_tok_pos, input_tok_mask,\
            input_ent_text, input_ent_text_length, input_ent_type, input_ent_mask, \
            cand_name, cand_name_length,cand_description, cand_description_length,cand_type, cand_type_length, cand_mask, \
            labels)
        loss = outputs[0]
        prediction_scores = outputs[1]
        # Reshape the scores once to [batch, input_ent_text.size(1)-1, -1] and
        # reuse the view for both the ranking and the gathered scores.
        scores_per_entity = prediction_scores.view(input_ent_text.size(0),input_ent_text.size(1)-1,-1)
        predict_index = torch.argsort(scores_per_entity,descending=True)
        sorted_scores = torch.gather(scores_per_entity,-1,predict_index).tolist()
        predict_index = predict_index.tolist()
        acc = metric.accuracy(prediction_scores, labels.view(-1),ignore_index=-1)
        cand_length = cand_mask.sum(1).tolist()
        # Labels equal to -1 are ignored, so this counts the labelled entities of each table.
        ent_length = (labels!=-1).sum(1).tolist()
        for i,t_id in enumerate(table_id):
            # Keep only the first ent_length[i] entity rows and the first
            # cand_length[i] ranked candidates of each row.
            test_results.append([t_id,entities_index[i],\
                                 [x[:cand_length[i]] for x in predict_index[i][:ent_length[i]]],\
                                 [x[:cand_length[i]] for x in sorted_scores[i][:ent_length[i]]],\
                                ])
        test_loss += loss.mean().item()
        test_acc += acc.item()
    nb_test_steps += 1

# Averages are per batch, not per example.
test_loss = test_loss / nb_test_steps
test_acc = test_acc / nb_test_steps

result = {
    "eval_loss": test_loss,
    "eval_acc": test_acc,
}
for key in sorted(result.keys()):
    print("%s = %s"%(key, str(result[key])))

In [None]:
# We dump the predictions to a separate file and use another script for the official evaluation.
# The reason is that our entity linking is based on Wikidata lookup. In certain cases the candidates
# do not contain the target entity; such test examples are still counted in the metric calculation.
# However, since there is nothing to rank, we do not pass those examples here, so the set of test
# examples here is incomplete.
# NOTE(review): the file name ends in `_0` while this notebook sets config.mode=2 —
# confirm whether the suffix is meant to follow the mode.
with open(os.path.join(data_dir,"test_own_entity_linking_results_0.pkl"),"wb") as f:
    pickle.dump(test_results, f)