# Imports & Setup

In [46]:
import logging
import math
import os
import random
import re
import string
import urllib3
from typing import List, Dict, Union, Tuple, Optional

import faiss
import networkx as nx
import nltk
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from nltk.corpus import stopwords
from sentence_transformers import (
    LoggingHandler,
    SentencesDataset,
    SentenceTransformer,
    losses,
    models,
    util,
)
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator
from sentence_transformers.evaluation import (
    BinaryClassificationEvaluator,
    EmbeddingSimilarityEvaluator,
)
from sentence_transformers.readers import InputExample, STSBenchmarkDataReader
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
STOPWORDS = set(stopwords.words("english"))

In [2]:
logging.basicConfig(
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
    handlers=[LoggingHandler()]
)
logger = logging.getLogger(__name__)

In [3]:
SESSION_PATH = "./"
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

biencoder_name = "facebook/mcontriever"
crossencoder_name = "jeffwan/mmarco-mMiniLMv2-L12-H384-v1"

## Preprocessing

In [50]:
class DatasetProcessor:
    """Static helpers that convert, filter and sample sentence-pair data."""

    @staticmethod
    def _save_dataframe(df: pd.DataFrame, output_file: Optional[str]) -> None:
        """Write ``df`` as CSV to ``output_file`` inside ``SESSION_PATH``.

        Raises:
            ValueError: if ``output_file`` is missing or empty.
        """
        if not output_file:
            raise ValueError("output_file must be provided when save=True")
        df.to_csv(os.path.join(SESSION_PATH, output_file))

    @staticmethod
    def sts_to_dataframe(data, output_file: Optional[str] = None, save: bool = False) -> pd.DataFrame:
        """Convert STS samples into a ``sentence1`` / ``sentence2`` / ``label`` frame.

        Args:
            data: sized iterable of mappings with the keys ``sentence1``,
                ``sentence2`` and ``similarity_score``.
            output_file: CSV file name (relative to ``SESSION_PATH``); required
                when ``save`` is true.
            save: write the resulting frame to ``output_file``.

        Returns:
            The frame, with ``label`` equal to ``similarity_score / 5.0``.

        Raises:
            ValueError: if ``save`` is true and ``output_file`` is not given.
        """
        rows = []
        with tqdm(total=len(data), desc="Normalizing Scores") as progressbar:
            for sample in data:
                # Assumes the raw score is on a 0-5 scale, so the label lands in 0-1
                # -- TODO confirm against the dataset card.
                score = float(sample["similarity_score"]) / 5.0
                rows.append([sample["sentence1"], sample["sentence2"], score])
                progressbar.update()

        df = pd.DataFrame(rows, columns=["sentence1", "sentence2", "label"])

        if save:
            DatasetProcessor._save_dataframe(df, output_file)

        return df

    @staticmethod
    def qoura_to_dataframe(data, output_file: Optional[str] = None, save: bool = False) -> pd.DataFrame:
        """Convert Quora samples into a ``sentence1`` / ``sentence2`` frame.

        Args:
            data: sized iterable of mappings where ``sample["questions"]["text"]``
                holds the two questions of a pair.
            output_file: CSV file name (relative to ``SESSION_PATH``); required
                when ``save`` is true.
            save: write the resulting frame to ``output_file``.

        Returns:
            The unlabeled pair frame.

        Raises:
            ValueError: if ``save`` is true and ``output_file`` is not given.
        """
        rows = []
        with tqdm(total=len(data)) as progressbar:
            for sample in data:
                texts = sample["questions"]["text"]
                rows.append([texts[0], texts[1]])
                progressbar.update()

        df = pd.DataFrame(rows, columns=["sentence1", "sentence2"])

        if save:
            DatasetProcessor._save_dataframe(df, output_file)

        return df

    @staticmethod
    def filter_text(text):
        """Return ``True`` when ``text`` should be dropped from the dataset.

        A text is dropped when it is not a string (e.g. a missing CSV cell),
        is empty or whitespace-only, consists only of digits, consists only of
        punctuation, contains a word longer than 15 characters, or is made up
        entirely of stopwords.
        """
        # Missing values arrive as NaN/None; treat them as unusable instead of crashing.
        if not isinstance(text, str):
            return True

        if not text.strip():
            return True

        if text.isdigit():
            return True

        if all(char in string.punctuation for char in text):
            return True

        words = text.split()
        if any(len(word) > 15 for word in words):
            return True

        return all(word.lower() in STOPWORDS for word in words)

    @staticmethod
    def filter_quora(df):
        """Return a copy of ``df`` without the rows where either sentence is filtered.

        The result has a fresh 0..n-1 index; ``df`` itself is not modified.
        """
        # Column-wise map keeps the mask a boolean Series even for an empty frame.
        drop_first = df["sentence1"].map(DatasetProcessor.filter_text).astype(bool)
        drop_second = df["sentence2"].map(DatasetProcessor.filter_text).astype(bool)
        return df[~(drop_first | drop_second)].reset_index(drop=True)

    @staticmethod
    def create_sbert_dataset(df: Union[pd.DataFrame, str], negative_sampling=True, label=True):
        """Build a list of ``InputExample`` objects from a labeled pair frame.

        Args:
            df: a frame with ``sentence1``, ``sentence2`` and ``label`` columns,
                or the path of a CSV file with those columns.
            negative_sampling: accepted for backward compatibility; not used.
            label: accepted for backward compatibility; not used.

        Returns:
            One ``InputExample`` per row, in row order.

        Raises:
            TypeError: if a sentence is not a string or a label is not a float.
        """
        data = pd.read_csv(df) if isinstance(df, str) else df

        examples = []
        with tqdm(total=len(data)) as progressbar:
            for sentence1, sentence2, score in zip(data["sentence1"], data["sentence2"], data["label"]):
                if not isinstance(sentence1, str) or not isinstance(sentence2, str):
                    raise TypeError(
                        f"sentences must be str, got {type(sentence1).__name__} and {type(sentence2).__name__}"
                    )
                if not isinstance(score, (float, np.floating)):
                    raise TypeError(f"label must be a float, got {type(score).__name__}")

                examples.append(InputExample(texts=[sentence1, sentence2], label=float(score)))
                progressbar.update()

        return examples

    @staticmethod
    def negative_sample(df: pd.DataFrame, n: int, to_df: bool = True):
        """Draw ``n`` sentence pairs that are not connected through any pair in ``df``.

        Every row of ``df`` becomes an edge between its two sentences; a
        negative pair is two sentences with no path between them in that graph.

        Args:
            df: frame with ``sentence1`` and ``sentence2`` columns.
            n: number of negative pairs to draw.
            to_df: return a frame with a ``label`` column set to ``0.0`` instead
                of a list of ``(sentence1, sentence2)`` tuples.
        """
        G = nx.Graph()
        G.add_edges_from(zip(df["sentence1"], df["sentence2"]))

        all_sentences = list(set(df["sentence1"]).union(df["sentence2"]))
        negative_samples = [DatasetProcessor.generate_negative_sample(G, all_sentences) for _ in tqdm(range(n))]
        if to_df:
            negative_samples = pd.DataFrame(negative_samples, columns=["sentence1", "sentence2"])
            negative_samples["label"] = 0.0
        return negative_samples

    @staticmethod
    def generate_negative_sample(G, all_sentences):
        """Return two random sentences that have no path between them in ``G``.

        Retries until such a pair is drawn, so it does not terminate when every
        sentence is reachable from every other one. ``random.sample`` raises
        ``ValueError`` when fewer than two sentences are available.
        """
        while True:
            s1, s2 = random.sample(all_sentences, 2)
            # The direct-edge test is a cheap shortcut before the path search.
            if not G.has_edge(s1, s2) and not nx.has_path(G, s1, s2):
                return s1, s2


class DataLabeler:
    """Scores sentence pairs with a cross-encoder and attaches the scores as labels."""

    @staticmethod
    def label(cross_encoder, df):
        """Return ``cross_encoder.predict`` scores for every row of ``df``.

        Each row contributes the pair ``[sentence1, sentence2]``; all pairs are
        scored in a single ``predict`` call.
        """
        pairs = []
        progressbar = tqdm(total=len(df))
        with progressbar:
            for _, record in df.iterrows():
                pairs.append([record["sentence1"], record["sentence2"]])
                progressbar.update()

        return cross_encoder.predict(pairs, show_progress_bar=True)

    @staticmethod
    def create_silver(df, scores):
        """Store ``scores`` in the ``label`` column of ``df`` and return ``df``.

        The frame is modified in place; the returned object is the same frame.
        """
        labels = scores.tolist()
        df["label"] = labels
        return df

In [5]:
%%capture
sts_dataset = load_dataset("stsb_multi_mt", name="en")

In [6]:
sts_df_train = DatasetProcessor.sts_to_dataframe(sts_dataset["train"])
sts_df_dev = DatasetProcessor.sts_to_dataframe(sts_dataset["dev"])
sts_df_test = DatasetProcessor.sts_to_dataframe(sts_dataset["test"])

Normalizing Scores:   0%|          | 0/5749 [00:00<?, ?it/s]

Normalizing Scores:   0%|          | 0/1500 [00:00<?, ?it/s]

Normalizing Scores:   0%|          | 0/1379 [00:00<?, ?it/s]

In [7]:
sts_train = DatasetProcessor.create_sbert_dataset(sts_df_train) # The gold dataset
# Each split keeps its own name, so the "dev" evaluator is not fed the test split.
sts_dev = DatasetProcessor.create_sbert_dataset(sts_df_dev)
sts_test = DatasetProcessor.create_sbert_dataset(sts_df_test)

  0%|          | 0/5749 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1379 [00:00<?, ?it/s]

## Training Cross-Encoder

In [10]:
%%capture
batch_size = 64
num_epochs = 5

# The gold STS examples are shuffled and batched for cross-encoder training.
train_dataloader = DataLoader(sts_train, shuffle=True, batch_size=batch_size)
evaluator = CECorrelationEvaluator.from_input_examples(sts_dev, name="dev")
# Warm-up covers 10% of all training batches (batches per epoch * epochs).
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)

# A single output label passed through a sigmoid activation.
cross_encoder = CrossEncoder(
    crossencoder_name, num_labels=1,
    default_activation_function=torch.nn.Sigmoid()
)

In [15]:
cross_encoder = CrossEncoder(
    "kaggle/working/crossencoder/", num_labels=1,
    default_activation_function=torch.nn.Sigmoid()
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


2024-04-05 21:13:59 - Use pytorch device: mps


In [None]:
cross_encoder.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=num_epochs,
    evaluation_steps=100,
    warmup_steps=warmup_steps,
    output_path=os.path.join(SESSION_PATH, "crossencoder")
)

## Create Silver

In [52]:
#%%capture
from datasets import load_dataset
dataset = load_dataset("quora")
dataset = dataset["train"].train_test_split(train_size=30_000, test_size=75_000)

In [54]:
# NOTE(review): filter_quora returns a new frame and the result is not assigned
# here, so quora_df_dev itself stays unfiltered for every later cell.
DatasetProcessor.filter_quora(quora_df_dev)

Unnamed: 0,sentence1,sentence2
0,How are bacteria beneficial to humans?,How can bacteria be harmful to humans?
1,"I am turning 30 this week, and I am lost caree...",Why do some people have a tenacious personality?
2,Which is the best site to play indian rummy wi...,Which is the best site to play an Indian rummy?
3,What is your most common problem?,What are the most common problems in teaching?
4,Can we build cars that run on tap water?,Why don't we have cars that run on water?
...,...,...
73214,Yale University: What was it like to attend Ya...,Yale University: What was it like to attend Ya...
73215,How is the formula for copper II sulfate hydra...,What is the formula for copper II sulfate?
73216,How do I prove the existence of God to an athe...,How can you prove to someone that GOD exists?
73217,Why has eBay failed?,Why did eBay fail in China?


In [53]:
quora_df_train = DatasetProcessor.qoura_to_dataframe(dataset["train"])
quora_df_dev = DatasetProcessor.qoura_to_dataframe(dataset["test"])

  0%|          | 0/30000 [00:00<?, ?it/s]

  0%|          | 0/75000 [00:00<?, ?it/s]

In [None]:
train_labels = DataLabeler.label(cross_encoder, quora_df_train)
dev_labels = DataLabeler.label(cross_encoder, quora_df_dev)

quora_df_train = DataLabeler.create_silver(quora_df_train, train_labels)
quora_df_dev = DataLabeler.create_silver(quora_df_dev, dev_labels)

In [15]:
# 5000 unconnected sentence pairs per split; negative_sample labels them 0.0.
negative_samples_train = DatasetProcessor.negative_sample(quora_df_train, 5000)
negative_samples_dev = DatasetProcessor.negative_sample(quora_df_dev, 5000)

# Silver set = labeled Quora pairs + gold STS pairs + sampled negatives, re-indexed from 0.
silver_train = pd.concat([quora_df_train, sts_df_train, negative_samples_train]).reset_index(drop=True)
silver_dev = pd.concat([quora_df_dev, sts_df_dev, negative_samples_dev]).reset_index(drop=True)

  0%|          | 0/5000 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]

In [16]:
silver_train_sbert = DatasetProcessor.create_sbert_dataset(silver_train)
silver_dev_sbert = DatasetProcessor.create_sbert_dataset(silver_dev)

  0%|          | 0/40749 [00:00<?, ?it/s]

  0%|          | 0/8500 [00:00<?, ?it/s]

## Training Bi-Encoder

In [None]:
%%capture
encoder = models.Transformer(biencoder_name, max_seq_length=128)
pooling_model = models.Pooling(
    encoder.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True,
    pooling_mode_cls_token=False,
    pooling_mode_max_tokens=False
)
bi_encoder = SentenceTransformer(modules=[encoder, pooling_model])

In [None]:
batch_size = 32
num_epochs = 5

# Training and evaluation need the InputExample lists built by create_sbert_dataset;
# the raw DataFrames (silver_train / silver_dev) cannot be batched by the DataLoader.
train_dataloader = DataLoader(silver_train_sbert, shuffle=True, batch_size=batch_size)
train_loss = losses.CosineSimilarityLoss(model=bi_encoder)
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(silver_dev_sbert, name="dev")

# Warm-up covers 10% of the total number of training steps.
warmup_steps = math.ceil(len(silver_train_sbert) * num_epochs / batch_size * 0.1)
logging.info("Warmup-steps: {}".format(warmup_steps))

bi_encoder.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    output_path=os.path.join(SESSION_PATH, "biencoder")
)

## Load Fine-Tuned Bi-Encoder & Encode Test Samples

In [None]:
cross_encoder = CrossEncoder(
    "kaggle/working/crossencoder/", num_labels=1,
    default_activation_function=torch.nn.Sigmoid()
)

In [20]:
%%capture
bi_encoder = SentenceTransformer("kaggle/working/biencoder")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [56]:
quora_df_dev.to_csv("data/test.csv", index=False)

In [57]:
all_sentence1 = quora_df_dev["sentence1"].tolist()
all_sentence2 = quora_df_dev["sentence2"].tolist()
all_sentences = list(set(all_sentence1 + all_sentence2))

In [58]:
all_sentences_embedding = bi_encoder.encode(all_sentences, show_progress_bar=True)

Batches:   0%|          | 0/4015 [00:00<?, ?it/s]

In [59]:
np.save("data/embeddings.npy", all_sentences_embedding)

In [60]:
all_sentences_embedding = np.load("data/embeddings.npy")

In [61]:
all_sentences_embedding.shape

(128469, 768)

## Approximate Search

In [62]:
def get_sentence_vector(model, sentence, preprocess=True):
    """Encode ``sentence`` with ``model.encode`` (progress bar disabled).

    ``preprocess`` is accepted but not used by this function.
    """
    vector = model.encode(sentence, show_progress_bar=False)
    return vector

def faiss_search(model, embeddings, query, k):
    """Return the indices of the ``k`` rows of ``embeddings`` nearest to ``query``.

    A new ``faiss.IndexFlatL2`` is built from ``embeddings`` on every call.
    The query is encoded with ``get_sentence_vector`` and turned into a
    single-row matrix before the search. Only the index matrix is returned;
    the distances are discarded.
    """
    dimension = embeddings.shape[1]
    flat_index = faiss.IndexFlatL2(dimension)
    flat_index.add(embeddings)

    encoded = get_sentence_vector(model, query, False)
    # (d,) -> (d, 1) -> (1, d): the search expects one query per row.
    query_matrix = encoded[:, None].T

    _, neighbour_ids = flat_index.search(query_matrix, k=k)
    return neighbour_ids

def rerank(I, query, sentences, cross_encoder):
    """Re-order the first row of search hits ``I`` by cross-encoder score.

    Returns a list of ``(sentence, score, original_rank)`` tuples sorted by
    score, highest first. ``original_rank`` is the position of the sentence
    in ``I[0]``.
    """
    candidates = [sentences[idx] for idx in I[0]]

    scored = []
    for rank, candidate in enumerate(candidates):
        score = cross_encoder.predict([query, candidate], show_progress_bar=False)
        scored.append((candidate, score, rank))

    scored.sort(key=lambda item: item[1], reverse=True)
    return scored

In [65]:
query = "i want to trash my quora account"

I = faiss_search(bi_encoder, all_sentences_embedding, query, k=5)
rerank(I, query, all_sentences, cross_encoder)

[('How do I delete my account at Quora?', 0.6080411, 1),
 ('How do I delete Quora account?', 0.5275337, 0),
 ('How do I get a Quora Account?', 0.4865062, 2),
 ('How did I get a Quora account?', 0.44115204, 3),
 ('What does "Quora" mean?', 0.15222643, 4)]

In [66]:
query = "is there any god?"

I = faiss_search(bi_encoder, all_sentences_embedding, query, k=5)
rerank(I, query, all_sentences, cross_encoder)

[('Is there one god or many?', 0.71820015, 0),
 ('Where is god?', 0.6188236, 1),
 ('Is there any proof that God exists?', 0.5178925, 4),
 ('What is a "god"?', 0.5027556, 2),
 ('What is the meaning of one god?', 0.34638438, 3)]

In [74]:
query = "deep learning"

I = faiss_search(bi_encoder, all_sentences_embedding, query, k=5)
rerank(I, query, all_sentences, cross_encoder)

[('How is deep learning used in finance?', 0.71429634, 3),
 ('Who is the best deep learning teacher?', 0.6813439, 4),
 ('Is deep learning used in trading?', 0.6724121, 1),
 ('How does deep residual learning work?', 0.6713976, 0),
 ('Kevin Murphy: What do you think of Deep Reinforcement Learning?',
  0.6100112,
  2)]

In [83]:
query = "smoking"

I = faiss_search(bi_encoder, all_sentences_embedding, query, k=10)
rerank(I, query, all_sentences, cross_encoder)

[('What would be the best way to quit smoking?', 0.39267415, 8),
 ('What are some ways to quit smoking?', 0.35488284, 6),
 ('How do one quit smoking?', 0.34245422, 5),
 ('How do I quit smoking and drinking?', 0.340137, 7),
 ('How do you quit smoking?', 0.33898452, 3),
 ('how to quit smoking', 0.3258829, 0),
 ('How do I quit smoking?', 0.3130946, 4),
 ('How do stop smoking?', 0.3014821, 9),
 ('?', 0.09699556, 1),
 ('o', 0.09331581, 2)]

## kNN Search with Elasticsearch

In [84]:
from elasticsearch import Elasticsearch, helpers

In [85]:
ES_LOCALHOST = "http://127.0.0.1:9200"

# Credentials are read from the environment so that no secret is stored in the
# notebook. The password that used to be written here should be rotated.
USERNAME = os.environ.get("ES_USERNAME", "elastic")
PASSWORD = os.environ["ES_PASSWORD"]

client = Elasticsearch(
    "http://localhost:9200",
    http_auth=(USERNAME, PASSWORD), 
    verify_certs=False, 
    use_ssl=True
)



In [86]:
def prepare_dataset(sentences, embeddings):
    """Pair every sentence with its embedding as an indexable document.

    Returns a list of dicts with the keys ``question`` (the sentence) and
    ``question_vector`` (the embedding converted with ``.tolist()``).
    Pairing stops at the shorter of the two inputs.
    """
    return [
        {"question": text, "question_vector": vector.tolist()}
        for text, vector in zip(sentences, embeddings)
    ]

In [87]:
documents = prepare_dataset(all_sentences, all_sentences_embedding)

In [88]:
def bulk(documents):
    """Lazily yield one bulk action per document for the ``inzva-index`` index."""
    yield from (
        {"_index": "inzva-index", "_source": source}
        for source in documents
    )


success, _ = helpers.bulk(client, bulk(documents), chunk_size=10000)

2024-04-05 21:53:02 - POST https://localhost:9200/_bulk [status:200 request:6.495s]
2024-04-05 21:53:10 - POST https://localhost:9200/_bulk [status:200 request:6.337s]
2024-04-05 21:53:19 - POST https://localhost:9200/_bulk [status:200 request:6.937s]
2024-04-05 21:53:27 - POST https://localhost:9200/_bulk [status:200 request:6.905s]
2024-04-05 21:53:37 - POST https://localhost:9200/_bulk [status:200 request:7.334s]
2024-04-05 21:53:46 - POST https://localhost:9200/_bulk [status:200 request:7.530s]
2024-04-05 21:53:56 - POST https://localhost:9200/_bulk [status:200 request:7.815s]
2024-04-05 21:54:09 - POST https://localhost:9200/_bulk [status:200 request:8.998s]
2024-04-05 21:54:17 - POST https://localhost:9200/_bulk [status:200 request:6.006s]
2024-04-05 21:54:26 - POST https://localhost:9200/_bulk [status:200 request:6.641s]
2024-04-05 21:54:35 - POST https://localhost:9200/_bulk [status:200 request:7.304s]
2024-04-05 21:54:44 - POST https://localhost:9200/_bulk [status:200 request:

In [89]:
post_index_settings = {
    "refresh_interval": "60s",
    "number_of_replicas": 1
}
client.indices.put_settings(
    index="inzva-index",
    body={"settings": post_index_settings},
)

2024-04-05 21:56:12 - PUT https://localhost:9200/inzva-index/_settings [status:200 request:0.375s]


{'acknowledged': True}

In [96]:
def query_to_embedding(bi_encoder, query):
    """Encode ``query`` with ``bi_encoder`` and return the result of ``.tolist()``."""
    encoded = bi_encoder.encode(query)
    return encoded.tolist()


def knn_query(query_embedding, query, boost, num_candidates, size):
    """Build the ``knn`` clause that searches the ``question_vector`` field.

    ``size`` is used as ``k``. ``query`` is accepted but not used.
    """
    knn_clause = dict(
        field="question_vector",
        query_vector=query_embedding,
        k=size,
        num_candidates=num_candidates,
        boost=boost,
    )
    return {"knn": knn_clause}


def multi_match_query(query, boost, size):
    """Build the lexical part of a search body: a ``multi_match`` on ``question``.

    Returns a dict with ``size`` and a ``bool``/``should`` query holding one
    ``multi_match`` clause weighted by ``boost``.

    NOTE(review): the ``query`` parameter is never read. The search text comes
    from the module-level name ``query_text``, so the function only works once
    that global has been assigned, and it always searches for that global's
    current value.
    """
    return {
        "size": size,
        "query": {
        "bool": {
          "should": [
            {
              "multi_match": {
                # Reads the global ``query_text``, not the ``query`` argument.
                "query": query_text,
                "fields": ["question"],
                "boost": boost
              }
            }
          ]
        }
      }
    }


def hybrid_query(query_embedding, query, knn_boost=1, multi_match_boost=1, num_candidates=300, size=15):
    """Combine the kNN clause and the multi_match body into one search body.

    The returned dict holds the keys produced by ``knn_query`` (``knn``) and
    by ``multi_match_query`` (``size`` and ``query``).

    NOTE(review): the ``query`` argument is discarded. The name is rebound to
    an empty dict on the first line, and that dict is what both helpers
    receive as their ``query`` argument.
    """
    # Rebinds the parameter; from here on ``query`` is the request body being built.
    query = {}
    
    knn = knn_query(query_embedding, query, knn_boost, num_candidates, size)
    multi_match = multi_match_query(query, multi_match_boost, size)

    query.update(knn)
    query.update(multi_match)
    return query


def search(client, query):
    """Run ``query`` against ``inzva-index`` and return ``(score, question)`` tuples.

    The tuples follow the order of the hits in the response.
    """
    response = client.search(index="inzva-index", body=query)
    return [
        (hit["_score"], hit["_source"]["question"])
        for hit in response["hits"]["hits"]
    ]

In [97]:
query_text = "spring boot"
query_embedding = query_to_embedding(bi_encoder, query_text)

# Pass the search text itself: at this point ``query`` only holds whatever an
# earlier cell left in it (or nothing at all on a fresh kernel).
query = hybrid_query(query_embedding, query_text, 0.9, 0.1)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [100]:
search(client, query)

2024-04-05 21:58:22 - POST https://localhost:9200/inzva-index/_search [status:200 request:0.028s]


[(2.593286,
  'What is the difference between Spring Boot and the Spring framework?'),
 (2.1398466,
  'Should I learn Spring boot Framework or Spring Framework? What is the difference between them and which one is the best?'),
 (1.7295177, 'What is boot process?'),
 (1.6919584, 'What is a boot camp?'),
 (1.6444197, 'What is a boot loader?'),
 (1.6259985, 'How do you learn spring framework?'),
 (1.6195064, 'How do I learn Spring Framework?'),
 (1.5687168, 'How do I learn Spring Framework? Help?'),
 (1.5658484, 'Where can I get spring framework videos?'),
 (1.5274035, 'How is still water different from spring water?'),
 (1.4190416, 'How do I boot from SSD with clone of Windows 10?'),
 (1.0340872, 'Is the Arab Spring over?'),
 (1.0282699, 'Coding boot camps Bay area?'),
 (0.9876111, 'What exactly is the Arab Spring?'),
 (0.9876111, 'How good is Poland Spring water?')]

In [105]:
query_text = "huawei"
query_embedding = query_to_embedding(bi_encoder, query_text)

# Pass the search text itself: ``query`` still holds the request body built by
# the previous search cell.
query = hybrid_query(query_embedding, query_text, 0.9, 0.1)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [106]:
search(client, query)

2024-04-05 21:59:09 - POST https://localhost:9200/inzva-index/_search [status:200 request:0.095s]


[(1.6322346, 'Are there any functional huawei service centers in Gurgaon?'),
 (1.0528336, 'Iphone 6 plus vs Huawei p9 plus?'),
 (1.0094175, 'Which Huawei Switches Support Static Multicast Router Ports?'),
 (1.0094175, 'Did my Huawei y 360-431 sappoting vidiocall?'),
 (1.0094175, 'What is difference LTE Huawei and LTE alcatel?'),
 (0.9694406, 'Where can I find Huawei service center in Gurgaon?'),
 (0.9694406, 'For what does Huawei Mate 9 need four microphones?'),
 (0.9325094, 'When will the Huawei Honor 4X get the Lollipop update?'),
 (0.89828885, 'How do I turn off the screen overlay in Huawei P8lite?'),
 (0.86649084, 'Which is the best custom rom for the Huawei Honor 3c 4G?'),
 (0.8368672,
  'How is the after sales service of Xiaomi, Huawei and Gionee in India?'),
 (0.8092021,
  'Should I update my Huawei Honor 7 to Marshmallow from Lollipop. Why? Why not?'),
 (0.7590188,
  'Have Huawei Mate 9 Porsche design launched on the market? How/where can I reserve it?'),
 (0.6571369,
  'How do