In [None]:
from nltk.tokenize import sent_tokenize
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from gensim.models.word2vec import Word2Vec
from multiprocessing import cpu_count
import os
import re
import nltk
import torch
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
# Fetch the NLTK data packages requested by this notebook at import time:
# 'punkt' (tokenizer data) and 'stopwords' (the corpus read through
# stopwords.words(...) below). These calls run every time the cell executes.
nltk.download('punkt')
nltk.download('stopwords')

In [None]:
class Def_handler():
    """Turn a definition phrase into a ranked list of candidate words.

    Every usable word of the phrase is looked up in a word2vec model; the
    nearest neighbours of those words that start with a given prefix are
    collected and ranked by how many phrase words produced them.

    A model must be attached with ``set_model`` before ``phrase_to_keys``
    is called.
    """

    def __init__(self, a_tops=10):
        """Create a handler.

        Args:
            a_tops: number of nearest neighbours requested from the model
                for each phrase word (passed as ``topn``).
        """
        self.tops = a_tops
        # NOTE(review): this pattern is applied with ``.match``, so it only
        # checks the START of a token, and 'ё'/'Ё' lie outside the а-я/А-Я
        # ranges. Kept as is to preserve the current filtering behaviour.
        self.r = re.compile("[а-яА-Я-]+")
        # A set makes the per-token stop-word test O(1) instead of a list scan.
        self.stops = set(stopwords.words("russian"))

    def set_model(self, a_model):
        """Attach a trained word2vec model and cache its vocabulary mapping.

        Args:
            a_model: object exposing a ``wv`` attribute (keyed vectors).

        Gensim 4 exposes the vocabulary as ``wv.key_to_index`` and no longer
        supports ``wv.vocab``; older releases only have ``wv.vocab``. The new
        attribute is therefore tried first and the old one used as fallback.
        """
        self.model = a_model
        keyed_vectors = a_model.wv
        vocabulary = getattr(keyed_vectors, "key_to_index", None)
        if vocabulary is None:
            vocabulary = keyed_vectors.vocab
        self.vocab = vocabulary

    def phrase_to_keys(self, prefix, phrase):
        """Return model words starting with ``prefix`` that relate to ``phrase``.

        Args:
            prefix: required beginning of every returned word; an empty
                string accepts every neighbour.
            phrase: definition text to analyse.

        Returns:
            List of candidate words, most frequently suggested first. Ties
            keep the order in which the candidates were first seen.
        """
        # Prepare definition words for the search: purely alphabetic tokens,
        # lower-cased, passing the Cyrillic pattern and not stop words.
        tokens = [
            token.lower()
            for token in word_tokenize(phrase)
            if token.isalpha()
        ]
        words = [
            token
            for token in tokens
            if self.r.match(token) and token not in self.stops
        ]

        found = defaultdict(int)
        for word in words:
            if word not in self.vocab:
                continue
            neighbours = self.model.wv.most_similar(word, topn=self.tops)
            for neighbour in neighbours:
                # Each neighbour is a sequence whose first item is the word.
                candidate = neighbour[0]
                if candidate.startswith(prefix):
                    found[candidate] += 1

        # sorted() is stable, so equal counts keep first-seen order.
        return sorted(found, key=found.get, reverse=True)

In [None]:
class Naive_metric():
    """Hit-rate metric: share of tests whose target word is among the results."""

    def __init__(self):
        """Start with no recorded tests."""
        self.n_success = 0
        self.n_test = 0

    def update(self, word, res):
        """Record one test.

        Args:
            word: the expected word.
            res: container of predicted words; the test counts as a success
                when ``word in res`` is true.
        """
        self.n_test += 1
        if word in res:
            self.n_success += 1

    def score(self):
        """Return the fraction of successful tests.

        Returns:
            ``n_success / n_test``, or ``0.0`` when no test has been recorded
            yet (instead of raising ``ZeroDivisionError``).
        """
        if self.n_test == 0:
            return 0.0
        return self.n_success / self.n_test

In [None]:
def test_model(model_name, test_set_name, def_handler, metric,
               n_defs=100, prefix_size=0):
    '''
    Evaluate how well a word2vec model recovers words from their definitions.

    For each of the first ``n_defs`` rows of the test set, the word is
    guessed from its definition (and an optional prefix of the word itself)
    and the metric is updated. The final score and the number of test words
    missing from the model vocabulary are printed.

    Args:
        model_name: path of the saved word2vec model, read with torch.load.
        test_set_name: path of a CSV file with 'word' and 'defs' columns.
        def_handler: object providing set_model(model) and
            phrase_to_keys(prefix, phrase).
        metric: object providing update(word, res) and score().
        n_defs: maximum number of rows to evaluate; clamped to the size of
            the test set. None evaluates every row.
        prefix_size: number of leading characters of the target word that
            are revealed to the handler (0 reveals nothing).

    example:

    test_model('wiki0_cbow_model', 'rus_test_set.csv', Def_handler(), Naive_metric())
    '''
    # Load the word2vec model.
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code - only load model files from a trusted source.
    # NOTE(review): recent PyTorch releases default to weights_only=True,
    # which may reject a pickled gensim model - confirm against the
    # installed version.
    model = torch.load(model_name)
    # Gensim 4 keeps the vocabulary in wv.key_to_index; older releases only
    # have wv.vocab, so fall back to it.
    vocab = getattr(model.wv, "key_to_index", None)
    if vocab is None:
        vocab = model.wv.vocab
    # Configure the handler.
    def_handler.set_model(model)

    # Test set: a word and its definition phrase.
    test_set = pd.read_csv(test_set_name)
    n_rows = len(test_set)
    # Never index past the end of the test set.
    n_defs = n_rows if n_defs is None else min(n_defs, n_rows)

    not_in_dict = 0

    for i in tqdm(range(n_defs)):
        row = test_set.iloc[i]
        word = row['word']
        if word not in vocab:
            not_in_dict += 1
            continue

        # Try to guess the word from its definition and prefix;
        # the result is a ranked list of words.
        res = def_handler.phrase_to_keys(word[:prefix_size], row['defs'])
        # Update the metric.
        metric.update(word, res)

    print('Test score: ', metric.score())
    print('Not in dictionary: ', not_in_dict)