In [1]:
# text preprocessing
from pymorphy2 import MorphAnalyzer
import string
from stop_words import get_stop_words
import re
import json

# chat-bot creation
from gensim.models import Word2Vec, FastText
import gensim
import annoy

from tqdm.notebook import tqdm
import numpy as np
import pandas as pd

In [2]:
# Tab-separated "question<TAB>answer" file written by the preparation cell
# and read back when the search indexes are built.
PREP_ANS_PATH = "prepared_answers.txt"
# Raw answers dump used as input for preparation and for model training.
INIT_ANS_PATH = "Otvety.txt"

In [6]:
%%time
# Optional cell: there is no need to run it if prepared_answers.txt has
# already been generated (running it rewrites that file from scratch).
#
# Light preprocessing of the answers dump (the data come from a prepared
# dataset of mail.ru answers). Records are separated by lines starting with
# "---"; the first line of a record is taken as the question and the line
# right after it as the answer. Every pair is written to PREP_ANS_PATH as
# "<question>\t<answer>", with tabs inside either part replaced by spaces.

question = None
written = False

# The input is opened first so that a missing source file raises before the
# output file is truncated.
with open(INIT_ANS_PATH, "r", encoding='utf-8') as fin, \
        open(PREP_ANS_PATH, "w", encoding='utf-8') as fout:
    for line in tqdm(fin):
        if line.startswith("---"):
            # New record. Drop a question that never received an answer,
            # otherwise it would be paired with the first line of this record.
            written = False
            question = None
            continue
        if written:
            # Only the first answer of each record is kept.
            continue
        if question is None:
            question = line.strip()
            continue
        # Normalise the line terminator so that every record ends with exactly
        # one newline, even when the last line of the input has none.
        answer = line.replace("\t", " ").rstrip("\n")
        fout.write(question.replace("\t", " ").strip() + "\t" + answer + "\n")
        written = True
        question = None

0it [00:00, ?it/s]

CPU times: total: 20.5 s
Wall time: 37.2 s


Функция для предобработки текста

In [7]:
# Shared resources used by preprocess_txt.
# Characters removed from the text before tokenisation.
exclude = {*string.punctuation}
# Russian stop words that are dropped after lemmatisation.
sw = {*get_stop_words("ru")}
# Morphological analyser used to obtain the normal form of each token.
morpher = MorphAnalyzer()

In [8]:
def preprocess_txt(text):
    """Turn a raw string into a list of normalised tokens.

    Steps, in order: replace `<...>` tags with a space; replace everything up
    to and including the last colon on a line with a space; delete the
    characters listed in `exclude`; split on whitespace; lowercase each token
    and take the normal form of its first `morpher` parse; drop stop words
    (`sw`) and empty strings.

    Returns:
        list of str: the remaining normal forms, in their original order.
    """
    without_tags = re.sub(r'<.*?>', ' ', text)
    without_prefix = re.sub(r'.*:', ' ', without_tags)
    cleaned = "".join(filter(lambda ch: ch not in exclude, without_prefix.strip()))

    lemmas = []
    for token in cleaned.split():
        lemma = morpher.parse(token.lower())[0].normal_form
        if lemma != "" and lemma not in sw:
            lemmas.append(lemma)
    return lemmas

In [10]:
%%time
# Build the training corpus for the embedding models: every line of the raw
# dump is tokenised with preprocess_txt and stored as a list of tokens.
sentences = []
c = 0  # number of lines processed so far
with open(INIT_ANS_PATH, "r", encoding='utf-8') as fin:
    for c, raw_line in enumerate(tqdm(fin), start=1):
        sentences.append(preprocess_txt(raw_line))

0it [00:00, ?it/s]

CPU times: total: 25min 37s
Wall time: 26min 40s


In [14]:
%%time
# Keep only sentences with more than two tokens.
sentences = list(filter(lambda tokens: len(tokens) > 2, sentences))

# Both models are trained on the same corpus with the same hyper-parameters.
embedding_params = dict(vector_size=300, window=5, min_count=3, workers=8)
modelW2V = Word2Vec(sentences=sentences, **embedding_params)
modelFT = FastText(sentences=sentences, **embedding_params)

# Save the trained models to the working directory.
modelFT.save("ft_model")
modelW2V.save("w2v_model")

CPU times: total: 9min 30s
Wall time: 3min 25s


In [16]:
%%time

# Build one Annoy index per embedding model over the prepared questions.
# The dimensionality is read from the trained models instead of being
# hard-coded, so the indexes stay consistent if vector_size is changed in the
# training cell.
w2v_index = annoy.AnnoyIndex(modelW2V.wv.vector_size, 'angular')
ft_index = annoy.AnnoyIndex(modelFT.wv.vector_size, 'angular')

# Maps the item id used in both indexes to the answer text.
index_map = {}
counter = 0


def _mean_vector(tokens, keyed_vectors):
    """Average the embeddings of the tokens present in `keyed_vectors`.

    Returns a zero vector of length `keyed_vectors.vector_size` when none of
    the tokens has an embedding.
    """
    total = np.zeros(keyed_vectors.vector_size)
    known = 0
    for token in tokens:
        if token in keyed_vectors:
            total += keyed_vectors[token]
            known += 1
    return total / known if known > 0 else total


with open(PREP_ANS_PATH, "r", encoding='utf-8') as f:
    for line in tqdm(f):
        # Each record is "<question>\t<answer>". Lines without a tab are
        # skipped instead of aborting the whole run with an IndexError.
        spls = line.split("\t", 1)
        if len(spls) < 2:
            continue
        index_map[counter] = spls[1]
        question = preprocess_txt(spls[0])

        w2v_index.add_item(counter, _mean_vector(question, modelW2V.wv))
        ft_index.add_item(counter, _mean_vector(question, modelFT.wv))

        counter += 1

# 10 is the number of trees passed to Annoy's build().
w2v_index.build(10)
ft_index.build(10)

0it [00:00, ?it/s]

CPU times: total: 1h 7min 3s
Wall time: 1h 7min 34s


True

In [None]:
# Save index_map (item id -> answer text) as JSON.
# Note: JSON object keys are always strings, so the integer ids become "0",
# "1", ... in the file and come back as strings when it is loaded.
with open('index_map.txt', 'w') as map_file:
    json.dump(obj=index_map, fp=map_file)

In [18]:
# Save both Annoy indexes to the files 'w2v_index' and 'ft_index'.
# The last call is a bare expression, so its return value is what the cell
# displays.
w2v_index.save('w2v_index')
ft_index.save('ft_index')

True

In [19]:
def get_response(question, index, model, index_map, count_answer=3):
    """Return the stored answers whose questions are closest to `question`.

    The query is tokenised with `preprocess_txt`, the embeddings of the tokens
    known to `model` are averaged, and the nearest items are requested from
    `index`.

    Args:
        question: Raw query string.
        index: Object exposing `get_nns_by_vector(vector, n)`.
        model: Object whose `wv` attribute supports `in`, item access and
            exposes `vector_size`.
        index_map: Mapping from item id to answer text. Keys may be the
            integer ids or their string form (as produced by a JSON
            round-trip of the mapping).
        count_answer: Number of nearest neighbours to request.

    Returns:
        list: Answers for the neighbours returned by the index, in that order.

    Raises:
        KeyError: If a returned id is missing from `index_map` both as an int
            and as a string.
    """
    tokens = preprocess_txt(question)

    # Size the accumulator from the model so that embeddings of any
    # dimensionality work, not only 300.
    vector = np.zeros(model.wv.vector_size)
    norm = 0
    for word in tokens:
        if word in model.wv:
            vector += model.wv[word]
            norm += 1
    if norm > 0:
        vector = vector / norm
    # When no token is known the query stays a zero vector and the index is
    # still queried with it.

    def _answer_for(item_id):
        # A mapping reloaded from JSON has string keys; fall back to them.
        try:
            return index_map[item_id]
        except KeyError:
            return index_map[str(item_id)]

    answers = index.get_nns_by_vector(vector, count_answer)
    return [_answer_for(i) for i in answers]