# Setup

In [None]:
import gc
import torch
import pandas as pd
import flair.datasets as datasets

from tqdm import tqdm
from gensim.models import KeyedVectors
from flair.embeddings import BytePairEmbeddings
from flair.embeddings import DocumentPoolEmbeddings, SentenceTransformerDocumentEmbeddings
from wikipedia2vec import Wikipedia2Vec
from src.models.base import Base
from src.models.flair import BaseFlair
from src.models.wiki2vec import BaseWiki2Vec
from data.utils import getCandidates

In [None]:
# Directory prefix for the pre-trained embedding files loaded in the cells below.
EMB_PATH = "./embeddings/"
# AIDA entity-linking corpus as provided by flair; only its `test` split is
# iterated in this notebook.
aida = datasets.NEL_ENGLISH_AIDA()
# Lookup table read by get_entity_desc(); it is accessed through its 'entity'
# and 'description' columns.
entity_desc = pd.read_csv('./data/test.csv')

In [None]:
# Build one [mention text, tag, sentence text, document id] row per span in the
# AIDA test split. A '-DOCSTART-' line marks a document boundary: it advances
# the document counter and contributes no mentions.
mentions_tags = []
doc = 1162
for sentence in aida.test:
	context = sentence.to_plain_string()
	if context == '-DOCSTART-':
		doc += 1
		continue
	mentions_tags.extend(
		[span.text, span.tag, context, doc] for span in sentence.get_spans()
	)

In [None]:
def get_entity_desc(entity):
	"""Return the description stored for `entity`, or '' when there is none.

	The lookup uses the module-level `entity_desc` table: rows whose 'entity'
	column equals `entity` are selected and the 'description' of the first
	match is returned.

	Only the "no matching row" case yields ''. Other failures (for example a
	missing column in the table) propagate instead of being silently reported
	as a missing description.
	"""
	descriptions = entity_desc[entity_desc['entity'] == entity]['description'].values
	# An empty selection means the entity is unknown; callers treat '' as
	# "no description available".
	return descriptions[0] if len(descriptions) > 0 else ''

def get_candidates(mention, doc):
	"""Return the described candidate entities for `mention` in document `doc`.

	Candidate URLs come from the 'url' column of `getCandidates(doc, mention=...)`.
	The entity name is the last '/'-separated component of each URL. Candidates
	whose description lookup yields '' are left out.

	Returns a list of [entity_name, description] pairs in URL order.
	"""
	candidate_urls = getCandidates(doc, mention=mention)['url'].values
	names = [url.split('/')[-1] for url in candidate_urls]
	described = ((name, get_entity_desc(name)) for name in names)
	return [[name, description] for name, description in described if description != '']

# Testing base NED model using various types of embeddings

In [None]:
def test_batch(model, batch):
	"""Run `model.link` on every mention in `batch` that can be answered.

	`batch` is an iterable of (mention, tag, context, doc) rows. A mention is
	only evaluated when its tag appears among the candidate entity names;
	all other mentions are skipped and produce no output row.

	Returns a list of [mention, tag, predicted_tag, confidence] rows.
	"""
	results = []
	for mention, tag, context, doc in tqdm(batch):
		candidates = get_candidates(mention, doc)
		candidate_names = [candidate[0] for candidate in candidates]
		if tag not in candidate_names:
			# The tag is not in the candidate set, so this mention is skipped.
			continue
		predicted, confidence = model.link(mention, context, candidates=candidates)
		results.append([mention, tag, predicted, confidence])
	return results


def batch(l, n):
	"""Yield consecutive slices of `l`, each holding at most `n` items.

	The last slice is shorter when len(l) is not a multiple of `n`; an empty
	`l` yields nothing.
	"""
	starts = range(0, len(l), n)
	yield from (l[start:start + n] for start in starts)


def test(emb, model, docEmb=None, batchSize=0, saveAs=None):
	"""Evaluate a linker on `mentions_tags`, optionally save predictions, print accuracy.

	Args:
		emb: Embedding object (or identifier) passed to `docEmb` or `model`.
		model: Callable building the linker, called as `model(emb)` or
			`model(docEmb(emb))`.
		docEmb: Optional callable that wraps `emb` before it reaches `model`.
		batchSize: Mentions per batch; 0 evaluates everything as one batch.
			The linker is rebuilt for every batch.
		saveAs: Optional CSV path for the per-mention predictions.

	Raises:
		ValueError: If `batchSize` is negative (it would otherwise evaluate
			nothing and report an accuracy of 0).
	"""
	if batchSize < 0:
		raise ValueError("batchSize must be >= 0")

	batches = batch(mentions_tags, batchSize) if batchSize != 0 else [mentions_tags]
	preds = []
	for chunk in batches:
		with torch.no_grad():
			ned = model(docEmb(emb)) if docEmb is not None else model(emb)
			preds += test_batch(ned, chunk)
		# Drop our reference *before* collecting: while `ned` is still bound the
		# linker cannot be freed, and the next batch would build a second one
		# alongside it.
		del ned
		clear_gpu_cache([])

	res = pd.DataFrame(preds, columns=['mention', 'tag', 'predicted', 'confidence'])
	if saveAs is not None:
		res.to_csv(saveAs, index=False)

	total = res.shape[0]
	if total > 0:
		correct = int((res['tag'] == res['predicted']).sum())
		acc = (correct / total) * 100
	else:
		# No mention had its tag among the candidates; avoid dividing by zero.
		acc = 0
	print("Accuracy: ", acc)


def clear_gpu_cache(objects):
	"""Run the garbage collector and release PyTorch's unused cached GPU memory.

	`objects` is accepted for backward compatibility. A function cannot unbind
	names held by its caller, so rebinding the items here has no effect; callers
	must `del` their own references to a model before calling this for its
	memory to become collectable. This function only drops its own reference
	to `objects`, so that the passed container does not keep anything alive
	during the collection.
	"""
	del objects
	gc.collect()
	torch.cuda.empty_cache()
	

### Word2Vec Google News 300d

In [None]:
# Load the saved gensim KeyedVectors for word2vec Google News and evaluate the
# `Base` linker with them; predictions are written to ./results/base_word2vec.csv.
word2vec = KeyedVectors.load(EMB_PATH + 'word2vec-google-news-300')
test(word2vec, Base, saveAs='./results/base_word2vec.csv')
# Recorded result (cased): 52.37 % -- presumably the accuracy printed by test().

### GloVe Wiki-Gigaword 300d

In [None]:
# Load the saved gensim KeyedVectors for GloVe Wiki-Gigaword and evaluate the
# `Base` linker with them; predictions are written to ./results/base_glove.csv.
glove = KeyedVectors.load(EMB_PATH + 'glove-wiki-gigaword-300')
test(glove,Base, saveAs='./results/base_glove.csv')
# Recorded result (cased): 51.97 % -- presumably the accuracy printed by test().

### Byte-Pair, 300d

In [None]:
# English byte-pair embeddings (dim=300, syllables=200000), pooled into a
# document embedding with fine_tune_mode='nonlinear', then evaluated with the
# flair-based linker; predictions are written to ./results/base_byte_pair.csv.
byte_pair = BytePairEmbeddings('en', dim=300, syllables=200000)
bp_doc_emb = DocumentPoolEmbeddings([byte_pair], fine_tune_mode='nonlinear')
test(bp_doc_emb, BaseFlair, saveAs='./results/base_byte_pair.csv')
# Recorded result (cased): 51.00 % -- presumably the accuracy printed by test().

### FastText Wiki-News Subword 300d

In [None]:
# Load the saved gensim KeyedVectors for fastText Wiki-News subword vectors and
# evaluate the `Base` linker; predictions are written to ./results/base_fasttext.csv.
ftext = KeyedVectors.load(EMB_PATH + 'fasttext-wiki-news-subwords-300')
test(ftext, Base, saveAs='./results/base_fasttext.csv')
# Recorded result (cased): 39.44 % -- presumably the accuracy printed by test().

### RoBERTa

In [None]:
# Here `emb` is the model name string; test() wraps it with
# SentenceTransformerDocumentEmbeddings and rebuilds the linker for each batch
# of 300 mentions. Predictions are written to ./results/base_roberta.csv.
test(
	'roberta-base',
	BaseFlair,
	docEmb=SentenceTransformerDocumentEmbeddings,
	batchSize=300, 
	saveAs='./results/base_roberta.csv')
# Recorded result (cased): 34.92 % -- presumably the accuracy printed by test().

### Wikipedia2vec

In [None]:
# Load the pickled Wikipedia2Vec model and evaluate the Wikipedia2Vec-specific
# linker; predictions are written to ./results/base_wiki2vec.csv.
wiki2vec = Wikipedia2Vec.load(EMB_PATH + 'wiki2vec_w10_100d.pkl')
test(wiki2vec, BaseWiki2Vec, saveAs='./results/base_wiki2vec.csv')
# Recorded result (cased): 62.06% -- presumably the accuracy printed by test().