In [40]:
import sys
import pandas as pd
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.metrics import accuracy_score
from swisscom import launch
from tqdm.notebook import tqdm

import warnings
warnings.filterwarnings("ignore")

In [41]:
sys.path.append("..")


data = pd.read_csv('./dataset.csv')

In [42]:
X = data['text']
y = data['loc']


In [43]:
class KeyPhraseExtractionModel(BaseEstimator):
    def __init__(self, embedding_model='roberta-large-nli-stsb-mean-tokens', beta=0.8, alias_threshold=0.7, n=1):
        self.X = None
        self.y = None
        self.beta = beta
        self.n = n
        self.alias_threshold = alias_threshold
        self.embedding_distributor = launch.load_local_embedding_distributor(embedding_model)
        self.pos_tagger = launch.load_local_corenlp_pos_tagger()


    def fit(self, X, y):
        self.X = X
        self.y = y


    def predict(self, X):
        return list(map(lambda text: launch.extract_keyphrases(self.embedding_distributor, self.pos_tagger, text, self.n, 'en', self.beta, self.alias_threshold)[0][0], X))

In [44]:
betas = list(np.arange(0.7, 1.0, 0.1))
models = [
    'roberta-large-nli-stsb-mean-tokens',
    'roberta-base-nli-stsb-mean-tokens',
    'distilbert-base-nli-stsb-mean-tokens',
    'distilroberta-base-paraphrase-v1',
    'xlm-r-distilroberta-base-paraphrase-v1',
    'distilroberta-base-msmarco-v2',
    'LaBSE',
    'facebook/bart-large-mnli',
]

In [45]:
def gen_params(models, betas):
    for model in models:
        for beta in betas:
            yield model, beta


In [46]:
model = KeyPhraseExtractionModel()
model.fit(X, y)

best_model = None
best_beta = None
max_score = -1
for embedding_model, beta in tqdm(list(gen_params(models, betas))):
    model = KeyPhraseExtractionModel(embedding_model, beta)
    score = accuracy_score(y_true=y, y_pred=model.predict(X))
    if score > max_score:
        max_score = score
        best_model = embedding_model
        best_beta = beta

print('Best results: ')
print(f' Model: {best_model}')
print(f' Beta: {best_beta}')
print(f' Accuracy: {int(max_score * 100)}%')

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=32.0), HTML(value='')))

Exception when trying to download https://sbert.net/models/facebook/bart-large-mnli.zip. Response 404
Some weights of the model checkpoint at facebook/bart-large-mnli were not used when initializing BartModel: ['model.encoder.version', 'model.decoder.version']
- This IS expected if you are initializing BartModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Exception when trying to download https://sbert.net/models/facebook/bart-large-mnli.zip. Response 404
Some weights of the model checkpoint at facebook/bart-large-mnli were not used when initializing BartModel: ['model.encoder.version', 'model.decoder.version']
- This IS expected if you


Best results: 
 Model: xlm-r-distilroberta-base-paraphrase-v1
 Beta: 0.7
 Accuracy: 74%


In [47]:
statuses = []
relevance = []
predicted = []

embedding_distributor = launch.load_local_embedding_distributor(best_model)
pos_tagger = launch.load_local_corenlp_pos_tagger()

for (text, y_true) in tqdm(zip(X, y)):
    kp = launch.extract_keyphrases(embedding_distributor, pos_tagger, text, 5, 'en')
    y_pred = kp[0][0]
    predicted.append(y_pred)
    if y_true == y_pred:
        statuses.append('SUCCESS')
        relevance.append(None)
    elif kp[2].count(y_true) > 0:
        statuses.append('SUCCESS IN ALIAS')
        relevance.append(zip(kp[2], kp[1]))
    else:
        statuses.append('FAILURE')
        relevance.append(None)


data = { 'Result': statuses,
         'Expected': y,
         'Actual': predicted,
         'Aliases': relevance
         }

df = pd.DataFrame(data, columns=['Result', 'Expected', 'Actual', 'Aliases'])
df.style.hide_index()

HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




Result,Expected,Actual,Aliases
SUCCESS,beach,beach,
SUCCESS,beach,beach,
SUCCESS,beach,beach,
FAILURE,beach,sea,
SUCCESS,beach,beach,
SUCCESS,beach,beach,
SUCCESS,beach,beach,
SUCCESS,beach,beach,
SUCCESS,beach,beach,
SUCCESS,beach,beach,
