In [1]:
import fasttext
import os
import pandas as pd
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [2]:
# Load the test set. Later cells read the columns 'pytanie' (question word),
# 'termin 1'..'termin 4' (candidate terms) and 'odpowiedź' (correct answer).
df = pd.read_csv(
    '../wbst/min1000_eng.wbst.csv',
    encoding='utf8',
)

In [3]:
def get_true(df):
    """Yield the zero-based position of the correct answer among the terms.

    Rows are visited in ``df.iterrows()`` order. For each row, every column
    'termin 1'..'termin 4' whose value equals the row's 'odpowiedź' produces
    one yielded index (0-3).

    Note: a row where no term matches yields nothing, and a row where several
    terms match yields several indices, so the output lines up one-to-one with
    the rows only when exactly one term matches per row.
    """
    term_columns = ['termin ' + str(position + 1) for position in range(4)]
    for _, record in df.iterrows():
        for position, column in enumerate(term_columns):
            if record['odpowiedź'] == record[column]:
                yield position

def make_vectorizer(model):
    """Build a row-level vectorizer bound to ``model``.

    The returned function takes a row (anything indexable by column name) and
    returns a 5-tuple of ``model.get_word_vector`` results, in this order:
    'pytanie', 'termin 1', 'termin 2', 'termin 3', 'termin 4'.
    """
    fields = ('pytanie', 'termin 1', 'termin 2', 'termin 3', 'termin 4')

    def vectorize(row):
        return tuple(model.get_word_vector(row[field]) for field in fields)

    return vectorize

# Ground-truth labels: the zero-based index of the term equal to the answer.
# Every row starts at 0, so rows whose answer is 'termin 1' (or matches no
# term at all) keep label 0. If several terms match, the highest index wins
# because later iterations overwrite earlier ones.
l = len(df)
y_true = np.zeros(l)
for term_idx in (1, 2, 3):
    is_answer = df['odpowiedź'] == df['termin ' + str(term_idx + 1)]
    y_true[is_answer] = term_idx

In [4]:
# Evaluate every fastText model found in MODELS_DIR on the test set. For each
# row the predicted answer is the candidate term whose word vector has the
# highest cosine similarity to the question word's vector; predictions are
# then scored against y_true.
MODELS_DIR = '../dist_models'


def _rowwise_cosine(left, right):
    """Return the cosine similarity between matching rows of two 2-D arrays.

    Only the n row-i-vs-row-i similarities are computed, rather than the full
    n x n similarity matrix followed by a diagonal extraction. All-zero
    vectors get similarity 0 (their norm is replaced by 1 before dividing).
    """
    left_norm = np.linalg.norm(left, axis=1)
    right_norm = np.linalg.norm(right, axis=1)
    left_norm[left_norm == 0] = 1.0
    right_norm[right_norm == 0] = 1.0
    return np.einsum('ij,ij->i', left, right) / (left_norm * right_norm)


# Collect one record per model and build the frame once at the end:
# DataFrame.append was removed in pandas 2.0 and re-copied the frame each time.
records = []
# os.listdir returns entries in arbitrary order; sort for a reproducible table.
for model_file in sorted(os.listdir(MODELS_DIR)):
    model_path = os.path.join(MODELS_DIR, model_file)
    vectorizer = fasttext.FastText.load_model(model_path)
    query_vec = np.stack(df['pytanie'].apply(vectorizer.get_word_vector))
    ans_vec = [np.stack(df['termin ' + str(i+1)].apply(vectorizer.get_word_vector)) for i in range(4)]
    # sims has one row per candidate term and one column per test question.
    sims = np.stack([_rowwise_cosine(query_vec, ans_i_vec) for ans_i_vec in ans_vec])
    y_pred = np.argmax(sims, axis=0)
    records.append({
        # File name without its extension, independent of the directory length.
        'model': os.path.splitext(model_file)[0],
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, average='weighted'),
        'recall': recall_score(y_true, y_pred, average='weighted'),
        'f1': f1_score(y_true, y_pred, average='weighted')
    })

metrics = pd.DataFrame(records, columns=['model', 'accuracy', 'precision', 'recall', 'f1'])



In [5]:
# Show the per-model accuracy / precision / recall / f1 table.
metrics

Unnamed: 0,model,accuracy,precision,recall,f1
0,full_m1,0.576437,0.576853,0.576437,0.57646
1,full_m2,0.594791,0.595828,0.594791,0.594966
2,sample_m1,0.541598,0.541853,0.541598,0.541654
3,sample_m2,0.545884,0.546567,0.545884,0.545972
