## Seminar: Neural Retrieval Models. Part I.


In this seminar we explore the [gensim](https://radimrehurek.com/gensim/auto_examples/index.html#documentation) library for obtaining word embeddings and try using them to train a DSSM.

### Working with gensim

Download the MS MARCO dataset:

In [1]:
import numpy as np
import pandas as pd

from tqdm.notebook import tqdm
from rank_bm25 import BM25Okapi
from datasets import load_dataset

from nltk.tokenize import WordPunctTokenizer

In [2]:
msmarco_dataset = load_dataset("Tevatron/msmarco-passage")

Read the dataset and convert it into a flat format with one row per (query, passage) pair.

In [3]:
def dataset_pandas(dataset):
    rows = []
    for i, row in tqdm(enumerate(dataset)):
        for pos_sample in row['positive_passages']:
            current_row = []
            current_row.append(i) # qid
            current_row.append(row['query']) # query
            current_row.append(pos_sample['text']) # text
            current_row.append(1.) # label
            rows.append(current_row)

        for neg_sample in row['negative_passages']:
            current_row = []
            current_row.append(i) # qid
            current_row.append(row['query']) # query
            current_row.append(neg_sample['text']) # text
            current_row.append(0.) # label
            rows.append(current_row)
    print(len(rows))

    return pd.DataFrame(rows, columns=['qid', 'query', 'text', 'label'])

In [4]:
data = dataset_pandas(msmarco_dataset['train'])

0it [00:00, ?it/s]

12346948


Split the dataset into train / val / test. The split is grouped by session (query), so all passages of a query end up in the same part.

In [5]:
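DATASET_SIZE = 400_000  # largest qid that is used
TEST_SIZE = 3_000       # number of queries in each of val and test

# The last TEST_SIZE qids go to test, the TEST_SIZE before them to val, the rest to train.
test_data = data[(DATASET_SIZE - TEST_SIZE < data['qid']) & (data['qid'] <= DATASET_SIZE)]
val_data = data[(DATASET_SIZE - 2 * TEST_SIZE < data['qid']) & (data['qid'] <= DATASET_SIZE - TEST_SIZE)]
train_data = data[data['qid'] <= DATASET_SIZE - 2 * TEST_SIZE]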
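Because the boundaries are expressed in terms of `qid`, a query can never be shared between two splits. The sketch below replays the same rule on toy data in plain Python. The helper name `assign_split` and the toy sizes are illustrative assumptions, not part of the notebook.

```python
def assign_split(qid, dataset_size, test_size):
    """Map a query id to a split using the same boundaries as the cell above."""
    if qid <= dataset_size - 2 * test_size:
        return "train"
    if qid <= dataset_size - test_size:
        return "val"
    if qid <= dataset_size:
        return "test"
    return None  # qids beyond dataset_size are left unused


# Toy rows: (qid, passage); every query has two passages.
rows = [(qid, f"passage {j} for query {qid}") for qid in range(12) for j in range(2)]

splits = {"train": set(), "val": set(), "test": set()}
for qid, _ in rows:
    name = assign_split(qid, dataset_size=10, test_size=3)
    if name is not None:
        splits[name].add(qid)

print({name: sorted(qids) for name, qids in splits.items()})
```

Each qid appears in exactly one split, so no query seen during training leaks into validation or test.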

In [6]:
train_data.head()

| | qid | query | text | label |
|---|---|---|---|---|
| 0 | 0 | where is whitemarsh island | Whitemarsh Island, Georgia. Whitemarsh Island ... | 1.0 |
| 1 | 0 | where is whitemarsh island | the strategy of island hopping was used by the... | 0.0 |
| 2 | 0 | where is whitemarsh island | For the island near Dunedin, see White Island,... | 0.0 |
| 3 | 0 | where is whitemarsh island | Jekyll Island, at 5,700 acres, is the smallest... | 0.0 |
| 4 | 0 | where is whitemarsh island | Sibu Island. A scuba diver at Sibu Island. Sib... | 0.0 |


Using gensim, we will compute embeddings for all queries and documents.

First, let's see which pretrained models gensim already provides:

In [7]:
import gensim.downloader

gensim.downloader.info()['models'].keys()

dict_keys(['fasttext-wiki-news-subwords-300', 'conceptnet-numberbatch-17-06-300', 'word2vec-ruscorpora-300', 'word2vec-google-news-300', 'glove-wiki-gigaword-50', 'glove-wiki-gigaword-100', 'glove-wiki-gigaword-200', 'glove-wiki-gigaword-300', 'glove-twitter-25', 'glove-twitter-50', 'glove-twitter-100', 'glove-twitter-200', '__testing_word2vec-matrix-synopsis'])

We will use fasttext-wiki-news-subwords-300:

In [8]:
embedder_model = gensim.downloader.load("fasttext-wiki-news-subwords-300")

We compute the embedding of a query or document as the mean of the embeddings of its tokens and score each (query, document) pair by the cosine similarity of the two mean vectors.

Then we compute MRR@10 and compare it with BM25.

In [9]:
def get_embedder_model_predictions(dataset):
    tokenizer = WordPunctTokenizer()

    def embed(texts):
        # Mean of token embeddings; tokens missing from the vocabulary are skipped.
        vectors = np.stack([
            embedder_model.get_mean_vector(tokenizer.tokenize(text), ignore_missing=True)
            for text in texts
        ])
        # L2-normalize so that the dot product below equals the cosine similarity.
        # The clip avoids division by zero for texts without any known token.
        norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
        return vectors / np.clip(norms, 1e-12, None)

    query_embeddings = embed(dataset["query"])
    document_embeddings = embed(dataset["text"])

    # Row-wise cosine similarity between each query and its paired document.
    return (query_embeddings * document_embeddings).sum(axis=-1)

embedder_model_predictions = get_embedder_model_predictions(test_data)
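To make the scoring rule concrete, here is a dependency-free sketch of "mean of token vectors, then cosine similarity". The three-dimensional `toy_vectors` are made up for illustration (real fastText vectors have 300 dimensions), and the helper names `mean_vector` and `cosine` are assumptions. It is a simplification of what gensim does, not a reimplementation of `get_mean_vector`.

```python
import math

# Made-up unit-length "word vectors", for illustration only.
toy_vectors = {
    "island": [1.0, 0.0, 0.0],
    "georgia": [0.8, 0.6, 0.0],
    "scuba": [0.0, 0.0, 1.0],
}
DIM = 3


def mean_vector(tokens, vectors, dim=DIM):
    """Average the vectors of known tokens; unknown tokens are skipped."""
    known = [vectors[token] for token in tokens if token in vectors]
    if not known:
        return [0.0] * dim
    return [sum(values) / len(known) for values in zip(*known)]


def cosine(u, v):
    """Cosine similarity; defined as 0.0 when either vector is all zeros."""
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(x * x for x in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return sum(x * y for x, y in zip(u, v)) / (norm_u * norm_v)


query = mean_vector(["where", "is", "island"], toy_vectors)
doc_a = mean_vector(["island", "georgia"], toy_vectors)
doc_b = mean_vector(["scuba", "island"], toy_vectors)

score_a = cosine(query, doc_a)
score_b = cosine(query, doc_b)
print(round(score_a, 3), round(score_b, 3))
```

The document that shares more of its meaning with the query gets the higher score, even though both documents contain the word "island".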

In [10]:
def get_bm25_model_predictions(dataset):
    tokenizer = WordPunctTokenizer()
    tokenized_corpus = [tokenizer.tokenize(doc) for doc in dataset['text'].values]
    # BM25 statistics (IDF, average length) are computed over the passages of this split only.
    bm25 = BM25Okapi(tokenized_corpus)

    queries = dataset['query'].unique()
    bm25_preds = np.zeros(len(dataset))
    for query in tqdm(queries):
        tokenized_query = tokenizer.tokenize(query)
        # Scores of this query against every passage in the split.
        doc_scores = bm25.get_scores(tokenized_query)
        # Keep only the scores of the passages paired with this query.
        mask = (dataset['query'] == query).values
        bm25_preds[mask] = doc_scores[mask]
    return bm25_preds

bm25_model_predictions = get_bm25_model_predictions(test_data)

  0%|          | 0/3000 [00:00<?, ?it/s]
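For intuition about what the BM25 baseline measures, here is a compact sketch of an Okapi BM25 scorer in plain Python. The function name `bm25_scores`, the defaults `k1=1.5` and `b=0.75`, and the non-negative IDF variant are assumptions of this sketch. `rank_bm25` may differ in details such as IDF smoothing, so the numbers are not expected to match `BM25Okapi` exactly.

```python
import math
from collections import Counter


def bm25_scores(query_tokens, tokenized_corpus, k1=1.5, b=0.75):
    """Score every document in the corpus against one tokenized query."""
    n_docs = len(tokenized_corpus)
    avg_len = sum(len(doc) for doc in tokenized_corpus) / n_docs
    # Document frequency: in how many documents each term occurs.
    doc_freq = Counter(term for doc in tokenized_corpus for term in set(doc))

    scores = []
    for doc in tokenized_corpus:
        term_freq = Counter(doc)
        # Longer-than-average documents are penalized through this term.
        length_norm = k1 * (1 - b + b * len(doc) / avg_len)
        score = 0.0
        for term in query_tokens:
            tf = term_freq[term]
            if tf == 0:
                continue
            # Non-negative IDF variant (an assumption; libraries differ here).
            idf = math.log(1 + (n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5))
            score += idf * tf * (k1 + 1) / (tf + length_norm)
        scores.append(score)
    return scores


corpus = [
    ["whitemarsh", "island", "georgia"],
    ["sibu", "island", "scuba", "diver"],
    ["jekyll", "acres"],
]
scores = bm25_scores(["whitemarsh", "island"], corpus)
print([round(s, 3) for s in scores])
```

Unlike the embedding model, BM25 only rewards exact token matches: a document without any query term scores zero.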

In [11]:
import torch
from torchmetrics.retrieval import RetrievalMRR


def MRR(preds, target, qids):
    mrr = RetrievalMRR(top_k=10)
    return mrr(torch.Tensor(preds), torch.Tensor(target),
               indexes=torch.LongTensor(qids - min(qids)))
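As a cross-check of what MRR@10 measures, the metric can be written in a few lines of plain Python: for each query, sort its documents by score, take the reciprocal rank of the first relevant document within the top `k`, and average over queries. The function `mrr_at_k` is an illustrative assumption. Its handling of ties and of queries without relevant documents may differ from `torchmetrics`.

```python
from collections import defaultdict


def mrr_at_k(preds, targets, qids, k=10):
    """Mean reciprocal rank of the first relevant document within the top k."""
    groups = defaultdict(list)
    for score, label, qid in zip(preds, targets, qids):
        groups[qid].append((score, label))

    reciprocal_ranks = []
    for items in groups.values():
        ranked = sorted(items, key=lambda pair: pair[0], reverse=True)[:k]
        rr = 0.0  # stays 0 if no relevant document appears in the top k
        for rank, (_, label) in enumerate(ranked, start=1):
            if label > 0:
                rr = 1.0 / rank
                break
        reciprocal_ranks.append(rr)
    return sum(reciprocal_ranks) / len(reciprocal_ranks)


# Query 0: the relevant document is ranked second. Query 1: it is ranked first.
toy_preds = [0.9, 0.8, 0.1, 0.3, 0.7]
toy_targets = [0, 1, 0, 0, 1]
toy_qids = [0, 0, 0, 1, 1]

print(mrr_at_k(toy_preds, toy_targets, toy_qids, k=10))
print(mrr_at_k(toy_preds, toy_targets, toy_qids, k=1))
```

With `k=1` the first query contributes zero, because its relevant document falls outside the cutoff.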

In [12]:
print(f"BM25 MRR@10: {MRR(bm25_model_predictions, test_data['label'].values, test_data['qid'].values):4f}")
print(f"Embeddings MRR@10: {MRR(embedder_model_predictions, test_data['label'].values, test_data['qid'].values):4f}")

BM25 MRR@10: 0.317946
Embeddings MRR@10: 0.207176


Let's try normalizing and blending the BM25 and embedding predictions to see whether the combination improves MRR@10.

In [13]:
# Min-max normalize both score arrays to [0, 1] over the whole test set.
bm25_model_predictions_normed = (bm25_model_predictions - bm25_model_predictions.min()) / (bm25_model_predictions.max() - bm25_model_predictions.min())
embedder_model_predictions_normed = (embedder_model_predictions - embedder_model_predictions.min()) / (embedder_model_predictions.max() - embedder_model_predictions.min())

# BM25 dominates; the embedding score acts as a small correction.
merged_predictions = bm25_model_predictions_normed + embedder_model_predictions_normed * 0.01
print(f"Ensemble MRR@10: {MRR(merged_predictions, test_data['label'].values, test_data['qid'].values):4f}")

Ensemble MRR@10: 0.319093


Further ways to improve the result:
* Average word embeddings with tf-idf weights;
* Tune the blending coefficient between the bm25 and embedder_model predictions;
* Normalize scores independently within each query;
* Use other methods for ensembling ranking models.
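Two of these ideas, per-query normalization and trying several blending coefficients, can be sketched in plain Python. The helper names `minmax_per_query` and `blend`, the toy scores, and the weight grid are illustrative assumptions. In the notebook the blended scores would be passed to the `MRR` function, and the weight would be picked on `val_data` rather than on `test_data`.

```python
def minmax_per_query(scores, qids):
    """Min-max normalize scores independently within each query."""
    bounds = {}
    for score, qid in zip(scores, qids):
        lo, hi = bounds.get(qid, (score, score))
        bounds[qid] = (min(lo, score), max(hi, score))

    normed = []
    for score, qid in zip(scores, qids):
        lo, hi = bounds[qid]
        # A query whose scores are all equal carries no ranking signal.
        normed.append((score - lo) / (hi - lo) if hi > lo else 0.0)
    return normed


def blend(bm25_scores, emb_scores, qids, weight):
    """BM25 score plus a weighted embedding score, both normalized per query."""
    bm25_normed = minmax_per_query(bm25_scores, qids)
    emb_normed = minmax_per_query(emb_scores, qids)
    return [b + weight * e for b, e in zip(bm25_normed, emb_normed)]


# Toy scores for two queries (qid 0 has three passages, qid 1 has two).
toy_qids = [0, 0, 0, 1, 1]
toy_bm25 = [12.0, 3.0, 0.0, 1.0, 0.5]
toy_emb = [0.2, 0.9, 0.4, 0.1, 0.3]

for weight in (0.0, 0.01, 0.1, 0.5, 1.0):
    blended = blend(toy_bm25, toy_emb, toy_qids, weight)
    print(weight, [round(x, 3) for x in blended])
```

Per-query normalization keeps one query with unusually large BM25 scores from compressing the normalized scores of all other queries toward zero, which is what global min-max scaling can do.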