# load data

In [1]:
import io
import requests
import polars as pl
import numpy as np

from zipfile import ZipFile


resp = requests.get("https://dl.fbaipublicfiles.com/glue/data/QQP-clean.zip")

In [2]:
arch = ZipFile(io.BytesIO(resp.content))

schema = {'id': pl.Int64, 
          'id_left': pl.Int64, 
          'id_right': pl.Int64, 
          'text_left': pl.String, 
          'text_right': pl.String, 
          'label': pl.Int8}

def read_tsv(path, schema):
    data = []
    first = True
    with arch.open(path) as zp:
        for line in zp.readlines():
            if first:
                first = False
            else:
                data.append(line.decode('utf-8').strip().split("\t"))
        data = pl.DataFrame(data=data, schema=schema)
    return data

In [3]:
train = read_tsv('QQP/train.tsv', schema)
test = read_tsv('QQP/dev.tsv', schema)

  data = pl.DataFrame(data=data, schema=schema)


In [4]:
test.head()

id,id_left,id_right,text_left,text_right,label
i64,i64,i64,str,str,i8
201359,303345,303346,"""Why are African-Americans so b…","""Why are hispanics so beautiful…",0
263843,69383,380476,"""I want to pursue PhD in Comput…","""I handle social media for a no…",0
172974,266948,175089,"""Is there a reason why we shoul…","""What are some reasons to trave…",1
15329,29298,29299,"""Why are people so obsessed wit…","""How can a single male have a c…",0
209794,314169,314170,"""What are some good baby girl n…","""What are some good baby girl n…",0


In [5]:
idx_df = pl.concat([test.unique('id_left').select(pl.col('id_left'), pl.col('text_left')).rename({'id_left': 'idx', 'text_left': 'text'}),
                    test.unique('id_right').select(pl.col('id_right'), pl.col('text_right')).rename({'id_right': 'idx', 'text_right': 'text'})
          ]).unique()

In [6]:
idx_to_text_mapping = {}
for idx, text in idx_df.iter_rows():
    idx_to_text_mapping.update({idx: text})

In [110]:
def create_val_pairs(inp_df: pl.DataFrame, fill_top_to: int = 15,
                     min_group_size: int = 2, seed: int = 0):
    # Берем только нужные столбцы
    inp_df_select = inp_df['id_left', 'id_right', 'label']
    
    # Смотрим сколько раз встречается левый вопрос
    inf_df_group_sizes = inp_df_select.group_by('id_left').agg(pl.col('id_right').count())
    
    # Берем индексы вопросов, которые больше порога
    glue_dev_leftids_to_use = inf_df_group_sizes.filter(pl.col('id_right') > min_group_size)['id_left'].to_list()
    
    # Оставляем только вопросы, прошедшие порог и группируем
    groups = inp_df_select.filter(pl.col('id_left').is_in(glue_dev_leftids_to_use)).group_by('id_left')

    all_ids = set(inp_df['id_left']).union(set(inp_df['id_right']))

    out_pairs = []

    np.random.seed(seed)

    # Итерируемся по получившемуся датасету
    for id_left, (_, id_right, label) in groups:
        id_left = id_left[0]
        # ID ПРАВОГО запроса, который является дубликатом ЛЕВОГО
        ones_ids = id_right.filter(label > 0)
        # ID ПРАВОГО запроса, который НЕ является дубликатом ЛЕВОГО
        zeroes_ids = id_right.filter(label == 0)
        sum_len = len(ones_ids) + len(zeroes_ids)
        # Считаем сколько не достает до максимального числа примеров
        num_pad_items = max(0, fill_top_to - sum_len)
        if num_pad_items > 0:
            # Рандомно выбираем из общего множества ID, которые не встречаются в этом примере
            cur_chosen = set(ones_ids).union(
                set(zeroes_ids)).union({id_left})
            pad_sample = np.random.choice(
                list(all_ids - cur_chosen), num_pad_items, replace=False).tolist()
        else:
            pad_sample = []
        # Формируем итоговые список
        # 2 - дубликат; 1 - похож но не дубликат, 2 - вообще мимо
        for i in ones_ids:
            out_pairs.append([id_left, i, 2])
        for i in zeroes_ids:
            out_pairs.append([id_left, i, 1])
        for i in pad_sample:
            out_pairs.append([id_left, i, 0])
    return out_pairs

In [261]:
idx_triplets = create_val_pairs(test)

# indexing

In [132]:
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk

In [133]:
es = Elasticsearch(["http://localhost:9200"])

# Название индекса
index_name = "validation"

In [134]:
def create_index(es, index_name):
    """Создает индекс в Elasticsearch с базовой схемой."""
    if es.indices.exists(index=index_name):
        print(f"Индекс {index_name} уже существует.")
        return

    mapping = {
        "mappings": {
            "properties": {
                "question": {
                    "type": "text",
                    "analyzer": "standard", 
                    "search_analyzer": "english" 
                },
                "index": {
                    "type": "keyword"
                }
            }
        }
    }

    es.indices.create(index=index_name, body=mapping)
    print(f"Индекс {index_name} создан.")

In [135]:
create_index(es, index_name)

Индекс validation создан.


In [137]:
from typing import Iterable

def add_docs_to_elasticsearch(es, index_name, docs: Iterable, colnames):
    """Индексирует данные  в Elasticsearch."""
    actions = [
        {
            "_index": index_name,
            "_source": {col: row[col] for col in colnames}
        }
        for row in docs
    ]

    bulk(es, actions)
    print(f"Данные проиндексированы в {index_name}.")

In [145]:
add_docs_to_elasticsearch(es, index_name, 
                          [{'index': k, 'question': v} for k, v in idx_to_text_mapping.items()], ['index', 'question'])

Данные проиндексированы в validation.


In [176]:
def fuzzy_search(query, size=10):
    """Выполняет нечеткий поиск по индексу."""
    body = {
        "size": size,
        "query": {
            "match": {
                "question": {
                    "query": query,
                    "fuzziness": "AUTO"
                }
            }
        }
    }
    response = es.search(index=index_name, body=body)
    return response["hits"]["hits"]

def match_fuzzy_search(es, index_name, query):
    """Выполняет полнотекстовый и нечеткий поиск по индексу."""
    body = {
        "query": {
            "bool": {
                "should": [
                    {
                        "match": {
                            "question": {
                                "query": query,
                                "fuzziness": "AUTO"
                            }
                        }
                    },
                    {
                        "match": {
                            "question": {
                                "query": query
                            }
                        }
                    }
                ]
            }
        }
    }
    response = es.search(index=index_name, body=body)
    return response["hits"]["hits"]

In [149]:
query = "How to install Pithon?"

In [152]:
results = match_fuzzy_search(es, index_name, query)

print("Результаты поиска:")
for result in results:
    print(idx_to_text_mapping[result["_source"]["index"]])

Результаты поиска:
How can I install python/Django on my Mac?
How do I install a python program on a random PC?
How do I install Python 2.7.6 and Django 1.6 on Windows 7?
Why should I learn Python instead of Java?
How do I learn Python?
How do I learn Python?
I want to use Python 3.5 instead of 2.7. How do I do this by default?
How should I start learning Python?
How should I begin learning Python?
Can I use Python instead of PHP for web development?


In [153]:
results = fuzzy_search(es, index_name, query, size=15)

print("Результаты поиска:")
for result in results:
    print(idx_to_text_mapping[result["_source"]["index"]])

Результаты поиска:
How can I install python/Django on my Mac?
How do I install a python program on a random PC?
How do I install Python 2.7.6 and Django 1.6 on Windows 7?
Why should I learn Python instead of Java?
Can I use Python instead of PHP for web development?
Which is the best way to install Python and Django on Windows 7?
I want to use Python 3.5 instead of 2.7. How do I do this by default?
How do I learn Python?
How do I learn Python?
How can I run Python 2.7 code if I have Python 3.4 installed?
How should I start learning Python?
How should I begin learning Python?
How do I install Microsoft office?
How easy is it to learn Python?
How is the python course on Coursera?


# Esaticsearch validation

In [154]:
import math

def ndcg_k(ys_true: np.array, ys_pred: np.array, ndcg_top_k: int = 10) -> float:
    def dcg(ys_true, ys_pred):
        argsort = np.argsort(ys_pred)[::-1]
        argsort = argsort[:ndcg_top_k]
        ys_true_sorted = ys_true[argsort]
        ret = 0
        for i, l in enumerate(ys_true_sorted, 1):
            ret += (2 ** l - 1) / math.log2(1 + i)
        return ret
    ideal_dcg = dcg(ys_true, ys_true)
    pred_dcg = dcg(ys_true, ys_pred)
    return (pred_dcg / ideal_dcg)

In [262]:
val_dataset = pl.DataFrame({'id_left': [i[0] for i in idx_triplets],
                              'id_right': [i[1] for i in idx_triplets],
                              'label': [i[2] for i in idx_triplets]})

In [263]:
val_dataset = val_dataset.group_by('id_left').agg(pl.col('id_right'), pl.col('label'))

In [264]:
from tqdm import tqdm
ndcg = []
for id_left, id_right, label in tqdm(val_dataset.iter_rows(), total=len(val_dataset)):
    elastic_answer = {i['_source']['index']: i['_score'] for i in fuzzy_search(idx_to_text_mapping[id_left], size=10000)}
    ys_pred = np.array([elastic_answer[i] if i in elastic_answer else 0 for i in id_right])
    ndcg.append(ndcg_k(np.array(label), ys_pred))

INFO:elastic_transport.transport:POST http://localhost:9200/validation/_search [status:200 duration:0.189s]00<?, ?it/s]
INFO:elastic_transport.transport:POST http://localhost:9200/validation/_search [status:200 duration:0.167s],  4.78it/s]
INFO:elastic_transport.transport:POST http://localhost:9200/validation/_search [status:200 duration:0.196s],  5.18it/s]
INFO:elastic_transport.transport:POST http://localhost:9200/validation/_search [status:200 duration:0.106s],  4.93it/s]
INFO:elastic_transport.transport:POST http://localhost:9200/validation/_search [status:200 duration:0.160s],  5.93it/s]
INFO:elastic_transport.transport:POST http://localhost:9200/validation/_search [status:200 duration:0.158s],  5.83it/s]
INFO:elastic_transport.transport:POST http://localhost:9200/validation/_search [status:200 duration:0.158s],  5.81it/s]
INFO:elastic_transport.transport:POST http://localhost:9200/validation/_search [status:200 duration:0.158s],  5.79it/s]
INFO:elastic_transport.transport:POST ht

In [265]:
print('Elasticsearch NDCG@10', np.mean(ndcg))

Elasticsearch NDCG@10 0.9793075821794981


# KNRM validation

In [229]:
import sys
import os
import logging

sys.path.insert(0, os.path.join(os.getcwd(), os.pardir, 'ml_service'))

In [237]:
artifacts_dir = os.path.join(os.getcwd(), os.pardir, os.pardir, 'additional_data')

In [239]:
os.environ['EMB_PATH_GLOVE'] = os.path.join(artifacts_dir, 'glove.6B.50d.txt')
os.environ['EMB_PATH_KNRM'] = os.path.join(artifacts_dir, 'embeddings.bin')
os.environ['MLP_PATH'] = os.path.join(artifacts_dir, 'knrm_mlp.bin')
os.environ['VOCAB_PATH'] = os.path.join(artifacts_dir, 'vocab.json')

In [None]:
import main
from importlib import reload
reload(main)

In [None]:
rerank_model = main.Project()

In [272]:
import torch
from typing import List, Dict, Union, Callable

def collate_fn(batch_objs: List[Union[Dict[str, torch.Tensor], torch.FloatTensor]]):
    max_len_q1 = -1
    max_len_d1 = -1
    max_len_q2 = -1
    max_len_d2 = -1

    is_triplets = False
    for elem in batch_objs:
        if len(elem) == 3:
            left_elem, right_elem, label = elem
            is_triplets = True
        else:
            left_elem, label = elem

        max_len_q1 = max(len(left_elem['query']), max_len_q1)
        max_len_d1 = max(len(left_elem['document']), max_len_d1)
        if len(elem) == 3:
            max_len_q2 = max(len(right_elem['query']), max_len_q2)
            max_len_d2 = max(len(right_elem['document']), max_len_d2)

    q1s = []
    d1s = []
    q2s = []
    d2s = []
    labels = []

    for elem in batch_objs:
        if is_triplets:
            left_elem, right_elem, label = elem
        else:
            left_elem, label = elem

        pad_len1 = max_len_q1 - len(left_elem['query'])
        pad_len2 = max_len_d1 - len(left_elem['document'])
        if is_triplets:
            pad_len3 = max_len_q2 - len(right_elem['query'])
            pad_len4 = max_len_d2 - len(right_elem['document'])

        q1s.append(left_elem['query'] + [0] * pad_len1)
        d1s.append(left_elem['document'] + [0] * pad_len2)
        if is_triplets:
            q2s.append(right_elem['query'] + [0] * pad_len3)
            d2s.append(right_elem['document'] + [0] * pad_len4)
        labels.append([label])
    q1s = torch.LongTensor(q1s)
    d1s = torch.LongTensor(d1s)
    if is_triplets:
        q2s = torch.LongTensor(q2s)
        d2s = torch.LongTensor(d2s)
    labels = torch.FloatTensor(labels)

    ret_left = {'query': q1s, 'document': d1s}
    if is_triplets:
        ret_right = {'query': q2s, 'document': d2s}
        return ret_left, ret_right, labels
    else:
        return ret_left, labels
        
class RankingDataset(torch.utils.data.Dataset):
    def __init__(self, index_pairs_or_triplets: List[List[Union[str, float]]],
                 idx_to_text_mapping: Dict[str, str], vocab: Dict[str, int], oov_val: int,
                 preproc_func: Callable, max_len: int = 30):
        self.index_pairs_or_triplets = index_pairs_or_triplets
        self.idx_to_text_mapping = idx_to_text_mapping
        self.vocab = vocab
        self.oov_val = oov_val
        self.preproc_func = preproc_func
        self.max_len = max_len

    def __len__(self):
        return len(self.index_pairs_or_triplets)

    def _tokenized_text_to_index(self, tokenized_text: List[str]) -> List[int]:
        return [self.vocab.get(t, self.oov_val) for t in tokenized_text]

    def _convert_text_idx_to_token_idxs(self, idx: int) -> List[int]:
        text = self.idx_to_text_mapping[idx]
        preproc_text = self.preproc_func(text)
        return self._tokenized_text_to_index(preproc_text)

    def __getitem__(self, idx):
        query_doc_label = self.index_pairs_or_triplets[idx]
        left_elem = {}
        left_elem['query'] = self._convert_text_idx_to_token_idxs(query_doc_label[0])
        left_elem['document'] = self._convert_text_idx_to_token_idxs(query_doc_label[1])
        label = torch.tensor(query_doc_label[2])
        return left_elem, label

In [273]:
ds = RankingDataset(
                index_pairs_or_triplets=idx_triplets,
                idx_to_text_mapping=idx_to_text_mapping,
                vocab=rerank_model.vocab,
                oov_val=rerank_model.vocab['OOV'],
                preproc_func=rerank_model.simple_preproc
)

In [275]:
dl = torch.utils.data.DataLoader(
        ds, batch_size=256, num_workers=0,
        collate_fn=collate_fn, shuffle=False)

In [313]:
def valid(model: torch.nn.Module, val_dataloader: torch.utils.data.DataLoader) -> float:
    with torch.no_grad():
        labels_and_groups = val_dataloader.dataset.index_pairs_or_triplets
        labels_and_groups = pl.DataFrame({'id_left': [i[0] for i in labels_and_groups],
                              'id_right': [i[1] for i in labels_and_groups],
                              'rel': [i[2] for i in labels_and_groups]})

        all_preds = []
        for batch in (val_dataloader):
            inp_1, y = batch
            preds = model.model.predict(inp_1)
            preds_np = preds.detach().numpy()
            all_preds.append(preds_np)
        all_preds = np.concatenate(all_preds, axis=0)
        labels_and_groups = labels_and_groups.with_columns(preds = all_preds)

        ndcgs = []
        for cur_id in labels_and_groups['id_left'].unique():
            cur_df = labels_and_groups.filter(pl.col('id_left') == cur_id)
            ndcg = ndcg_k(cur_df['rel'].to_numpy().reshape(-1), cur_df['preds'].to_numpy().reshape(-1))
            if np.isnan(ndcg):
                ndcgs.append(0)
            else:
                ndcgs.append(ndcg)
        return np.mean(ndcgs)

In [316]:
print('KNRM NDCG@10', valid(rerank_model, dl).item()) # xD

KNRM NDCG@10 0.9396889693622675
