In [1]:
import sys
import os
# Put the parent of the current working directory on the import path so that
# the `utils` package imported in the next cell can be resolved.
NOTEBOOK_DIR = os.getcwd()
PROJECT_ROOT = os.path.abspath(os.path.join(NOTEBOOK_DIR, '..'))
# Guard the append: re-running this cell must not pile up duplicate entries
# in sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
import nltk
import pymorphy3

from utils import dataset_utils
from utils import metrics

# Direct search by keywords

## Load Data

In [3]:
# Input files; the relative paths are resolved against the working directory.
DATA_DIR = "../data"
REQUESTS_FILE = f"{DATA_DIR}/request_db.txt"
ADS_FILE = f"{DATA_DIR}/ads_db.txt"
MATCHING_FILE = f"{DATA_DIR}/matching_db.txt"

In [4]:
# Load the ads file as a list with one entry per line (line endings kept).
with open(ADS_FILE, encoding="utf-8") as f:
    ads_raw = list(f)

In [5]:
# Load the requests file as a list with one entry per line (line endings kept).
with open(REQUESTS_FILE, encoding="utf-8") as f:
    requests_raw = list(f)

In [6]:
# Ground-truth markup; used further down as the reference when scoring the
# predictions.
# NOTE(review): presumably maps request ids to matching ad ids — confirm
# against dataset_utils.load_matching_data.
true_markup = dataset_utils.load_matching_data(MATCHING_FILE)

## Preprocessing

In [7]:
def preprocess(text):
    """Flatten a raw record into a single trimmed, lowercase line.

    Escaped line breaks (a literal backslash followed by ``n``) and real
    newline characters are both turned into spaces; surrounding whitespace is
    then stripped and the result is lowercased.
    """
    single_line = text.replace("\\n", " ").replace("\n", " ")
    # many words have vectors only in lowercase
    return single_line.strip().lower()


def tokenize(text, language="english"):
    """Split text into tokens with NLTK's ``word_tokenize``.

    Args:
        text: the string to tokenize.
        language: name of the Punkt model NLTK uses for sentence splitting
            before word tokenization. Defaults to ``"english"``, which is
            what ``word_tokenize`` uses when no language is given, so
            existing callers get the same tokens as before.

    Returns:
        The list of tokens produced by ``nltk.tokenize.word_tokenize``.
    """
    # NOTE(review): the rest of this notebook works with Russian stop words
    # and a Russian morphological analyzer; language="russian" may suit the
    # data better — evaluate the metrics before switching the default.
    return nltk.tokenize.word_tokenize(text, language=language)


def remove_stop_words(tokens, stop_word_list):
    """Return the tokens that are not stop words, preserving their order.

    Args:
        tokens: iterable of hashable tokens (strings in this notebook).
        stop_word_list: iterable of stop words to drop; any iterable works,
            including a one-shot generator.

    Returns:
        A new list with every token found in ``stop_word_list`` removed.
    """
    # Build the lookup once. Testing membership against a list rescans it for
    # every token, and a one-shot iterable would be exhausted after the first
    # test; a set avoids both problems.
    if isinstance(stop_word_list, (set, frozenset)):
        stop_words = stop_word_list
    else:
        stop_words = frozenset(stop_word_list)
    return [tok for tok in tokens if tok not in stop_words]


def make_normal_forms(morph, tokens):
    """Map each token to the normal form of its first parse.

    ``morph`` must expose ``parse(token)`` returning an indexable sequence
    whose items have a ``normal_form`` attribute. Only the first item is
    used for each token.
    """
    lemmas = []
    for token in tokens:
        first_parse = morph.parse(token)[0]
        lemmas.append(first_parse.normal_form)
    return lemmas

In [8]:
# Download the NLTK "stopwords" resource into ../.venv/nltk_data (relative to
# the notebook directory).
nltk_data_dir = os.path.join(NOTEBOOK_DIR, "../.venv/nltk_data")
nltk.download("stopwords", download_dir=nltk_data_dir)

True

In [9]:
# Download the NLTK "punkt_tab" resource into ../.venv/nltk_data (relative to
# the notebook directory).
nltk_data_dir = os.path.join(NOTEBOOK_DIR, "../.venv/nltk_data")
nltk.download("punkt_tab", download_dir=nltk_data_dir)

True

In [10]:
# Russian stop-word list from NLTK's stopwords corpus (downloaded above);
# it is passed to remove_stop_words when the token lists are built.
rus_stop_words = nltk.corpus.stopwords.words("russian")

In [11]:
# Morphological analyzer handed to make_normal_forms to reduce tokens to
# their normal forms; created once and shared by both pipelines below.
morph = pymorphy3.MorphAnalyzer()

In [12]:
# Ad pipeline: clean text -> tokenize -> drop stop words -> normal forms.
# NOTE(review): stop words are filtered before lemmatization, and punctuation
# tokens emitted by the tokenizer do not appear to be filtered anywhere —
# confirm both are intended.
ad_tokens = [
    make_normal_forms(
        morph,
        remove_stop_words(tokenize(preprocess(ad_text)), rus_stop_words),
    )
    for ad_text in ads_raw
]

In [13]:
# Request pipeline: clean text -> tokenize -> drop stop words -> normal forms.
# NOTE(review): stop words are filtered before lemmatization, and punctuation
# tokens emitted by the tokenizer do not appear to be filtered anywhere —
# confirm both are intended.
req_tokens = [
    make_normal_forms(
        morph,
        remove_stop_words(tokenize(preprocess(req_text)), rus_stop_words),
    )
    for req_text in requests_raw
]

## Prediction

In [14]:
def predict_by_keywords(req_tokens, ad_tokens):
    """Match each request to the ads that contain all of its keywords.

    Ids are 1-based positions in the input lists, rendered as strings.

    Args:
        req_tokens: list of token lists, one per request.
        ad_tokens: list of token lists, one per ad.

    Returns:
        Dict mapping a request id to the list of ad ids (in ascending
        position order) whose tokens include every token of the request.
        Requests with no matching ad are omitted, and so are requests that
        have no tokens at all.
    """
    # Build each ad's vocabulary once so the subset checks below are set
    # lookups instead of repeated scans of the token lists.
    ad_vocabularies = [set(ad_tok_list) for ad_tok_list in ad_tokens]

    predictions = {}
    for req_id, req_tok_list in enumerate(req_tokens, start=1):
        required = set(req_tok_list)
        if not required:
            # A request without keywords would vacuously match every ad;
            # it carries no signal, so it yields no predictions.
            continue
        found_list = [
            str(ad_id)
            for ad_id, vocabulary in enumerate(ad_vocabularies, start=1)
            if required <= vocabulary
        ]
        if found_list:
            predictions[str(req_id)] = found_list
    return predictions

In [15]:
# Run the keyword matcher over the preprocessed request and ad token lists.
pred_markup = predict_by_keywords(req_tokens, ad_tokens)

In [16]:
# Compare the predicted markup with the ground truth and show the counts.
confusion_matrix = metrics.calc_confusion_matrix(
    true_markup,
    pred_markup,
    n_ads=len(ads_raw),
    n_requests=len(requests_raw),
)
confusion_matrix

{'TP': 81, 'FP': 40, 'TN': 87188, 'FN': 513}

In [17]:
# Derive the summary metrics from the confusion-matrix counts and display
# them as the cell output.
stats = metrics.calc_all_stats(confusion_matrix)
stats

{'accuracy': 0.9937031723258409,
 'precision': 0.6694214876033058,
 'recall': 0.13636363636363635,
 'f1': 0.22657342657342658}

In [18]:
# Print how this run's metrics differ from the previously saved ones.
# NOTE(review): whether this call also overwrites the saved stats is not
# visible from the notebook — confirm before re-running.
metrics.compare_with_saved_stats(stats, confusion_matrix)

-----------------------------------------------------------------------------------------
|	Metric		|	Old Value	|	New Value	|	Diff	|
-----------------------------------------------------------------------------------------
|	TP		|	73		|	81		|	📈 8	|
|	FP		|	524		|	40		|	📉 -484	|
|	TN		|	86701		|	87188		|	📈 487	|
|	FN		|	524		|	513		|	📉 -11	|
|	Prec		|	0.122		|	0.669		|	📈 0.547	|
|	Recall		|	0.122		|	0.136		|	📈 0.014	|
|	F1		|	0.122		|	0.227		|	📈 0.104	|

F1 📈 increased by 0.104, up to 22.7%, which is a significant growth 🚀
