<font size=10> Using embeddings for similarity search </font>

Suppose we have a large collection of questions and answers. When a user asks a question, we want to retrieve the most similar question in the collection to help them find an answer.

We could use text embeddings to allow for retrieving similar questions:

- During indexing, each question is run through a sentence embedding model to produce a numeric vector.
- When a user enters a query, it is run through the same sentence embedding model to produce a vector. To rank the responses, we calculate the similarity between each indexed question's vector and the query vector. Cosine similarity is a common choice for comparing embedding vectors.

This notebook gives a simple example of how this could be accomplished in Elasticsearch. The main script indexes ~20,000 questions from the StackOverflow dataset, then lets the user enter free-text queries against the dataset.
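The ranking step described above can be sketched in plain Python. This is only an illustration of cosine similarity on toy vectors, not the code path the notebook uses (Elasticsearch computes the similarity server-side). The names `cosine_similarity`, `rank_by_similarity` and `toy_index` are made up for this example.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rank_by_similarity(query_vector, indexed_vectors):
    """Return (doc_id, similarity) pairs, most similar first."""
    scored = [
        (doc_id, cosine_similarity(query_vector, vector))
        for doc_id, vector in indexed_vectors.items()
    ]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


# Toy 3-dimensional vectors standing in for real sentence embeddings.
toy_index = {
    "q1": [1.0, 0.0, 0.0],
    "q2": [0.7, 0.7, 0.0],
    "q3": [0.0, 0.0, 1.0],
}
ranking = rank_by_similarity([1.0, 0.1, 0.0], toy_index)
print([doc_id for doc_id, _ in ranking])
```

With real embeddings the vectors are much longer, but the ranking logic is the same: the question whose vector points in the direction closest to the query vector comes first.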

In [1]:
import json
import time

from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk

# Use tensorflow 1 behavior to match the Universal Sentence Encoder
# examples (https://tfhub.dev/google/universal-sentence-encoder/2).
import tensorflow.compat.v1 as tf

# Enable GPU memory growth so TensorFlow does not reserve all GPU memory up front.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)

1 Physical GPUs, 1 Logical GPUs


In [2]:
##### INDEXING #####

def index_data():
    print("Creating the 'posts' index.")
    client.indices.delete(index=INDEX_NAME, ignore=[404])

    with open(INDEX_FILE) as index_file:
        source = index_file.read().strip()
        client.indices.create(index=INDEX_NAME, body=source)

    docs = []
    count = 0

    with open(DATA_FILE) as data_file:
        for line in data_file:
            line = line.strip()

            doc = json.loads(line)
            if doc["type"] != "question":
                continue

            docs.append(doc)
            count += 1

            if count % BATCH_SIZE == 0:
                index_batch(docs)
                docs = []
                print("Indexed {} documents.".format(count))

        if docs:
            index_batch(docs)
            print("Indexed {} documents.".format(count))

    client.indices.refresh(index=INDEX_NAME)
    print("Done indexing.")

def index_batch(docs):
    titles = [doc["title"] for doc in docs]
    title_vectors = embed_text(titles)

    requests = []
    for i, doc in enumerate(docs):
        request = doc
        request["_op_type"] = "index"
        request["_index"] = INDEX_NAME
        request["title_vector"] = title_vectors[i]
        requests.append(request)
    bulk(client, requests)
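`index_data()` reads the index definition from `INDEX_FILE`, which is not shown in this notebook. The sketch below is a guess at what such a definition might contain, written as a Python dict. The field names `title`, `body`, `type` and `title_vector` come from the code above, and 512 is the output size of the Universal Sentence Encoder. Everything else, including the choice of `text` and `keyword` types, is an assumption.

```python
import json

# Hypothetical contents for data/posts/index.json; the real file is not shown here.
index_definition = {
    "mappings": {
        "properties": {
            "title": {"type": "text"},
            "body": {"type": "text"},
            "type": {"type": "keyword"},
            # The Universal Sentence Encoder produces 512-dimensional vectors,
            # so the dense_vector field must declare the same number of dims.
            "title_vector": {"type": "dense_vector", "dims": 512},
        }
    }
}

print(json.dumps(index_definition, indent=2))
```

The `dims` value has to match the length of the vectors returned by `embed_text`. If it does not, Elasticsearch is expected to reject the documents at indexing time.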

In [3]:
##### SEARCHING #####

def run_query_loop():
    for _ in range(5):  # Interrupts do not work here, so run the search function 5 times.
        try:
            handle_query()
        except KeyboardInterrupt:
            break

def handle_query():
    query = input("Enter query: ")

    embedding_start = time.time()
    query_vector = embed_text([query])[0]
    embedding_time = time.time() - embedding_start

    script_query = {
        "script_score": {
            "query": {"match_all": {}},
            "script": {
                "source": "cosineSimilarity(params.query_vector, doc['title_vector']) + 1.0",
                "params": {"query_vector": query_vector}
            }
        }
    }

    search_start = time.time()
    response = client.search(
        index=INDEX_NAME,
        body={
            "size": SEARCH_SIZE,
            "query": script_query,
            "_source": {"includes": ["title", "body"]}
        }
    )
    search_time = time.time() - search_start

    print()
    print("{} total hits.".format(response["hits"]["total"]["value"]))
    print("embedding time: {:.2f} ms".format(embedding_time * 1000))
    print("search time: {:.2f} ms".format(search_time * 1000))
    for hit in response["hits"]["hits"]:
        print("id: {}, score: {}".format(hit["_id"], hit["_score"]))
        print(hit["_source"])
        print()
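The script in `handle_query` adds 1.0 to the cosine similarity. Cosine similarity lies between -1 and 1, and Elasticsearch does not accept negative scores, which is presumably why the offset is there. The reported `_score` therefore lies between 0 and 2. A small sketch for converting a score back to the underlying cosine similarity follows. The helper name `score_to_cosine` is made up for this example, and the sample score is taken from the query output in this notebook.

```python
def score_to_cosine(score):
    """Undo the +1.0 offset applied in the script_score source."""
    return score - 1.0


# Score of the top hit for the query "cuda" in this notebook's output.
top_score = 1.8136228
print(round(score_to_cosine(top_score), 7))
```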

In [4]:
##### EMBEDDING #####

def embed_text(text):
    vectors = session.run(embeddings, feed_dict={text_ph: text})
    return [vector.tolist() for vector in vectors]


In [5]:
##### MAIN SCRIPT #####

import tensorflow_hub as hub
tf.disable_eager_execution()

if __name__ == '__main__':
    print('name=main')
    
    INDEX_NAME = "posts"
    INDEX_FILE = "data/posts/index.json"

    DATA_FILE = "data/posts/posts.json"
    BATCH_SIZE = 1000

    SEARCH_SIZE = 5

    GPU_LIMIT = 0.1
    
    print("Downloading pre-trained embeddings from tensorflow hub...")
    embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")
    text_ph = tf.placeholder(tf.string)
    embeddings = embed(text_ph)
    print("Done.")
    
    
    
    print("Creating tensorflow session...")
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = GPU_LIMIT
   
    session = tf.Session(config=config)
    print('running session...')
    

    
    session.run(tf.global_variables_initializer())
    #sess.run(tf.global_variables_initializer())
    print('ran session...')
    session.run(tf.tables_initializer())
    print("Done.")

    client = Elasticsearch()
    # Indexing is disabled here; uncomment the call below on the first run to build the index.
    # index_data()
    run_query_loop()

    print("Closing tensorflow session...")
    session.close()
    print("Done.")


name=main
Downloading pre-trained embeddings from tensorflow hub...
Instructions for updating:
If using Keras pass *_constraint arguments to layers.


Done.
Creating tensorflow session...
running session...
ran session...
Done.
Enter query: cuda





10000 total hits.
embedding time: 1630.90 ms
search time: 40.02 ms
id: cjDJvnQBox4EoX0F53Qs, score: 1.8136228
{'title': 'raytracing with CUDA', 'body': "I'm currently implementing a raytracer. Since raytracing is extremely computation heavy and since I am going to be looking into CUDA programming anyway, I was wondering if anyone has any experience with combining the two. I can't really tell if the computational models match and I would like to know what to expect. I get the impression that it's not exactly a match made in heaven, but a decent speed increasy would be better than nothing. "}

id: YjDKvnQBox4EoX0FAYfg, score: 1.5882075
{'title': 'CUDA global (as in C) dynamic arrays allocated to device memory', 'body': "So, im trying to write some code that utilizes Nvidia's CUDA architecture. I noticed that copying to and from the device was really hurting my overall performance, so now I am trying to move a large amount of data onto the device. As this data is used in numerous functio

Enter query: 4

10000 total hits.
embedding time: 12.50 ms
search time: 24.52 ms
id: 9TDKvnQBox4EoX0FJ6C2, score: 1.337199
{'title': 'Microsoft.ApplicationBlocks.Data.ODBCHelper?', 'body': "I've found mention of a data application block existing for ODBC, but can't seem to find it anywhere. If i didn't have a copy of the Access DB application block I wouldn't believe it ever existed either. Anyone know where to download either the DLL or the code-base from? --UPDATE: It is NOT included in either the v1, v2, or Enterprise Library versions of the Data ApplicationBlocks Thanks, Brian Swanson "}

id: NTDKvnQBox4EoX0FN62g, score: 1.3041186
{'title': 'AJAX-Framework', 'body': 'Which Ajax framework/toolkit can you recommend for building the GUI of web applications that are using struts? '}

id: tDDKvnQBox4EoX0FAYfg, score: 1.302273
{'title': 'ActiveRecord#save_only_valid_attributes', 'body': 'I\'m looking for a variation on the #save method that will only save attributes that do not have erro

In [6]:
import pandas as pd

In [8]:
posts_data=pd.read_json('data/posts/posts.json', lines=True)

In [9]:
posts_data.head()

Unnamed: 0,user,tags,questionId,creationDate,title,acceptedAnswerId,type,body,answerId
0,8,"[c#, winforms, type-conversion, decimal, opacity]",4,2008-07-31T21:42:52.667,While applying opacity to a form should we use...,7.0,question,I want to use a track-bar to change a form's o...,
1,9,"[html, css, css3, internet-explorer-7]",6,2008-07-31T22:08:08.620,Percentage width child element in absolutely p...,31.0,question,I have an absolutely positioned div containing...,
2,9,,4,2008-07-31T22:17:57.883,,,answer,An explicit cast to double isn't necessary. do...,7.0
3,1,"[c#, .net, datetime]",9,2008-07-31T23:40:59.743,Calculate age in C#,1404.0,question,Given a DateTime representing a person's birth...,
4,1,"[c#, datetime, time, datediff, relative-time-s...",11,2008-07-31T23:55:37.967,Calculate relative time in C#,1248.0,question,"Given a specific DateTime value, how do I disp...",


In [10]:
posts_data.shape

(100000, 9)

In [11]:
posts_data.type.describe()

count     100000
unique         2
top       answer
freq       81152
Name: type, dtype: object

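The script adds only questions to the index. Since 81,152 of the 100,000 posts are answers, that leaves 18,848 questions, so the Elasticsearch index holds roughly 20,000 documents.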
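The same count can be checked without pandas by reading the JSON-lines file directly, mirroring the `doc["type"]` filter used in `index_data()`. The sketch below runs on a tiny in-memory sample rather than the real `data/posts/posts.json`. The helper name `count_post_types` and the sample records are made up for illustration.

```python
import json
from collections import Counter


def count_post_types(lines):
    """Count the 'type' field across JSON-lines records, skipping blank lines."""
    return Counter(json.loads(line)["type"] for line in lines if line.strip())


# Tiny in-memory sample standing in for data/posts/posts.json.
sample = [
    '{"type": "question", "title": "Calculate age in C#"}',
    '{"type": "answer", "body": "Example answer text."}',
    '{"type": "question", "title": "Calculate relative time in C#"}',
]
counts = count_post_types(sample)
print(counts["question"], counts["answer"])
```

To run it against the real data, pass an open file handle instead of `sample`, for example `count_post_types(open(DATA_FILE))`.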