In [2]:
import gensim
# Pretrained word vectors, read through gensim's word2vec-format loader.
# Later cells query the vocabulary with "<word>_<POS>" tokens (e.g. "ракета_NOUN").
w2v_model = gensim.models.KeyedVectors.load_word2vec_format(
    "araneum_upos_skipgram_300_2_2018.vec"
)

In [41]:
from sklearn.externals import joblib
import pandas as pd
from nltk.stem.snowball import SnowballStemmer

In [42]:
def levenshtein_distance(a, b):
    """Break the Levenshtein distance between ``a`` and ``b`` into operation counts.

    Despite the name, the result is a 3-tuple ``(add, delete, change)`` rather
    than a single number; the sum of the three counts is the edit distance.

    The two inputs are reordered internally so that the shorter one runs along
    the row. As a consequence the counts are relative to the shorter sequence:
    ``add`` counts items of the longer sequence left without a counterpart,
    ``delete`` counts items of the shorter sequence left without a counterpart,
    and ``change`` counts mismatching aligned pairs. Swapping the arguments
    therefore yields the same tuple when the lengths differ.

    When several operations are equally cheap at a cell, the preference order
    is add, then delete, then change.
    """
    # Shorter sequence along the row, so only min(len(a), len(b)) + 1 cells are kept.
    short, long_ = (a, b) if len(a) <= len(b) else (b, a)
    width = len(short)

    # Row 0: reaching column `col` from nothing costs `col` deletes.
    row = [(0, col, 0) for col in range(width + 1)]
    for row_no, long_item in enumerate(long_, start=1):
        above = row
        # Column 0: `row_no` items of the longer sequence consumed, all counted as adds.
        row = [(row_no, 0, 0)]
        for col, short_item in enumerate(short, start=1):
            up, left, diag = above[col], row[col - 1], above[col - 1]
            candidates = (
                (up[0] + 1, up[1], up[2]),
                (left[0], left[1] + 1, left[2]),
                diag if short_item == long_item else (diag[0], diag[1], diag[2] + 1),
            )
            # min() keeps the first of equally cheap candidates: add, delete, change.
            row.append(min(candidates, key=sum))

    return row[width]

In [43]:
# Pre-trained classifier consumed by get_same_stem_russian() below.
# SECURITY: joblib.load unpickles the file and can execute arbitrary code;
# only load "rf_classifier.pkl" from a trusted source.
# NOTE(review): joblib is imported via sklearn.externals in this notebook; that
# shim was removed in scikit-learn 0.23 — confirm the pinned version or switch
# to a plain `import joblib`.
clf = joblib.load("rf_classifier.pkl")
# Snowball stemmer for Russian; used to reduce both words to stems before comparison.
stemmer = SnowballStemmer("russian")

def get_same_stem_russian(word1: str, word2: str) -> dict:
    """Run the loaded classifier on a pair of words and report its verdict.

    Both words are stemmed with the module-level ``stemmer``. The classifier
    ``clf`` then receives a single-row frame with the features ``maxLen`` and
    ``minLen`` (lengths of the longer and shorter stem) plus ``add``,
    ``delete`` and ``change`` (the operation counts returned by
    ``levenshtein_distance`` for the two stems).

    Returns:
        A dict with two keys:
        ``'predict'`` — the predicted label as ``int`` (callers in this
        notebook treat ``1`` as "same stem");
        ``'predict_proba'`` — the value from the second column of
        ``clf.predict_proba`` as ``float`` (presumably the probability of
        label 1 — confirm against ``clf.classes_``).
    """
    stems = [stemmer.stem(word1), stemmer.stem(word2)]
    shorter, longer = sorted(map(len, stems))
    n_add, n_delete, n_change = levenshtein_distance(*stems)

    # Column names and order must stay exactly as the classifier expects them.
    features = pd.DataFrame(
        data=[(longer, shorter, n_add, n_delete, n_change)],
        columns=['maxLen', 'minLen', 'add', 'delete', 'change'],
    )

    label = clf.predict(features)[0]
    second_column_proba = clf.predict_proba(features)[:, 1][0]
    return {'predict': int(label), 'predict_proba': float(second_column_proba)}

In [69]:
def assoc_list(word: str, topn: int = 5, extra_candidates: int = 20) -> list:
    """Return the nearest neighbours of ``word`` with same-stem duplicates removed.

    ``topn + extra_candidates`` neighbours are requested from ``w2v_model`` and
    walked in the order the model returns them. A candidate is kept only if
    ``get_same_stem_russian`` does not predict label 1 for it against the query
    word or against any candidate already kept.

    Args:
        word: Query token as stored in the model vocabulary, e.g. ``"ракета_NOUN"``.
        topn: Maximum number of associates to return.
        extra_candidates: How many neighbours beyond ``topn`` to fetch so that
            filtering still leaves enough results (defaults to 20).

    Returns:
        Up to ``topn`` ``(token, similarity)`` pairs exactly as produced by
        ``w2v_model.most_similar``; fewer if too many candidates are filtered out.

    Raises:
        Whatever ``w2v_model.most_similar`` raises for an unknown ``word``
        (not handled here).
    """
    def _lemma(token: str) -> str:
        # Text before the first "_". partition() returns the whole token when
        # there is no "_", whereas token[:token.find("_")] would slice with -1
        # and silently drop the last character.
        return token.partition("_")[0]

    candidates = w2v_model.most_similar(positive=[word], topn=topn + extra_candidates)

    seen_lemmas = [_lemma(word)]  # query word first, then every kept candidate
    res = []
    for candidate in candidates:
        if len(res) >= topn:
            # Enough results: skip the remaining (costly) classifier calls.
            break
        lemma = _lemma(candidate[0])
        # any() stops at the first same-stem match instead of testing all seen lemmas.
        if any(get_same_stem_russian(lemma, seen)['predict'] == 1 for seen in seen_lemmas):
            continue
        seen_lemmas.append(lemma)
        res.append(candidate)
    return res

In [73]:
# Example query: nearest associates of "ракета_NOUN" after same-stem filtering.
assoc_list("ракета_NOUN")

[('боеголовка_NOUN', 0.7894258499145508),
 ('баллистический_ADJ', 0.756476879119873),
 ('противокорабельный_ADJ', 0.7446105480194092),
 ('брпл_X', 0.7339468002319336),
 ('боезаряд_NOUN', 0.6867998242378235)]