In [2]:
import warnings
warnings.filterwarnings('ignore')

# Setup

In [3]:
import os

import flair.datasets as datasets
from tqdm import tqdm
from wikipedia2vec import Wikipedia2Vec

from data.utils import getCandidates, getDocument
from src.models.gbrt import GBRT

In [4]:
# Directory holding the pre-trained Wikipedia2Vec embedding file.
EMB_PATH = "./embeddings/"

# os.path.join yields the same path here, but unlike plain string concatenation
# it stays correct if EMB_PATH is later edited to drop its trailing separator.
# NOTE(review): Wikipedia2Vec.load presumably unpickles the file, so only load
# embedding files from a trusted source.
wiki2vec = Wikipedia2Vec.load(os.path.join(EMB_PATH, 'wiki2vec_w10_100d.pkl'))

# Entity linker built on top of the embeddings; the evaluation cells call
# model.link(mention, context, candidates) and unpack (prediction, confidence).
model = GBRT(wiki2vec)

In [5]:
# Load flair's NEL_ENGLISH_AIDA corpus. Per the log below it has train/testa/testb
# splits, and aida.test reads testb.
aida = datasets.NEL_ENGLISH_AIDA()

# One record per gold-annotated span in the test split:
#   [span text, gold tag, plain text of the containing sentence, document id]
# Every '-DOCSTART-' marker bumps the document counter, so documents after the
# first marker are numbered from 1163 onward.
# NOTE(review): the 1162 offset presumably aligns these ids with the ones
# data.utils.getCandidates/getDocument expect. Confirm.
mentions_tags = []
doc = 1162
for sentence in aida.test:
	context = sentence.to_plain_string()
	if context == '-DOCSTART-':
		doc += 1
		continue
	mentions_tags.extend(
		[span.text, span.tag, context, doc] for span in sentence.get_spans()
	)

2021-11-24 15:17:20,137 Reading data from C:\Users\athar\.flair\datasets\nel_english_aida
2021-11-24 15:17:20,139 Train: C:\Users\athar\.flair\datasets\nel_english_aida\train
2021-11-24 15:17:20,141 Dev: C:\Users\athar\.flair\datasets\nel_english_aida\testa
2021-11-24 15:17:20,142 Test: C:\Users\athar\.flair\datasets\nel_english_aida\testb


In [6]:
def get_candidates(mention, doc):
	"""Return the candidate entity titles for `mention` in document `doc`.

	Candidates come from ``getCandidates(doc, mention=mention)``. Only the last
	'/'-separated segment of each entry in its 'url' column is kept
	(e.g. 'a/b/Some_Title' -> 'Some_Title').
	NOTE(review): assumes every 'url' value is a string.
	"""
	candidate_df = getCandidates(doc, mention=mention)
	titles = []
	for url in candidate_df['url'].values:
		titles.append(url.rsplit('/', 1)[-1])
	return titles

# Initial Tests

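Use the sentence containing the mention as the context. Only mentions whose gold entity appears among the retrieved candidates are scored (4,250 of 4,497), so the accuracy below is conditional on candidate recall.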

In [7]:
# Evaluate linking with the mention's own sentence as context.
# Only mentions whose gold tag appears among the retrieved candidates are scored,
# so this accuracy is conditional on candidate recall.
# (The original cell also fetched getDocument(docNum) here; the result was never
# used, so that wasted fetch has been removed.)
preds = []
for mention, tag, context, docNum in tqdm(mentions_tags):
	candidates = get_candidates(mention, docNum)
	if tag in candidates:
		pred, _conf = model.link(mention, context, candidates)
		# NOTE(review): gold tags presumably use underscores (URL-style titles),
		# so spaces in the prediction are normalised before comparison.
		preds.append([mention, tag, pred.replace(' ','_')])

correct = sum(t == p for _, t, p in preds)
# An empty scored set would otherwise raise ZeroDivisionError; report NaN instead.
accuracy = round((correct / len(preds)) * 100, 2) if preds else float('nan')
print(f'Accuracy: {accuracy}%\nTotal test samples: {len(preds)}')

100%|██████████| 4497/4497 [00:49<00:00, 90.44it/s] 

Accuracy: 67.95%
Total test samples: 4250





Use the whole document containing the mention as the context.

In [8]:
# Evaluate linking with the full text of the mention's document as context.
# Scored subset: only mentions whose gold tag appears among the retrieved
# candidates, so this accuracy is conditional on candidate recall.
preds = []
doc_texts = {}  # docNum -> document text, so each document is fetched at most once
for mention, tag, context, docNum in tqdm(mentions_tags):
	candidates = get_candidates(mention, docNum)
	if tag in candidates:
		# Fetch the document only for scored mentions, and reuse it for later
		# mentions from the same document.
		# NOTE(review): caching assumes getDocument(docNum) is deterministic.
		if docNum not in doc_texts:
			doc_texts[docNum] = getDocument(docNum)
		docText = doc_texts[docNum]
		pred, _conf = model.link(mention, docText, candidates)
		# NOTE(review): gold tags presumably use underscores (URL-style titles).
		preds.append([mention, tag, pred.replace(' ','_')])

correct = sum(t == p for _, t, p in preds)
# An empty scored set would otherwise raise ZeroDivisionError; report NaN instead.
accuracy = round((correct / len(preds)) * 100, 2) if preds else float('nan')
print(f'Accuracy: {accuracy}%\nTotal test samples: {len(preds)}')

100%|██████████| 4497/4497 [03:08<00:00, 23.88it/s]

Accuracy: 64.38%
Total test samples: 4250



