# Semantic Search with Matching Engine and PaLM Embeddings

**Learning Objectives**
  1. Learn how to create text embeddings using the Vertex PaLM API
  2. Learn how to load embeddings in Vertex Matching Engine
  3. Learn how to query Vertex Matching Engine
  4. Learn how to build an information retrieval system based on semantic match
  
  
In this notebook, we implement a simple (albeit fast and scalable) [semantic search](https://en.wikipedia.org/wiki/Semantic_search#:~:text=Semantic%20search%20seeks%20to%20improve,to%20generate%20more%20relevant%20results.) retrieval system using [Vertex Matching Engine](https://cloud.google.com/vertex-ai/docs/matching-engine/overview) and [Vertex PaLM Embeddings](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings). In a semantic search system, the documents returned for a user query are ranked by their semantic match: they should match the intent or meaning of the query rather than its exact keywords, unlike in a boolean or keyword-based retrieval system. Such a semantic search system generally has two components:

* A component that produces semantically meaningful vector representations of both the documents and the user queries; we will use the [Vertex PaLM Embeddings](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings) API to create these embeddings, leveraging the [PaLM](https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html) large language model developed at Google.

* A component that stores the document embeddings and retrieves the most relevant documents, namely those whose embeddings are closest to the user-query embedding in the embedding space. We will use [Vertex Matching Engine](https://cloud.google.com/vertex-ai/docs/matching-engine/overview), which can scale to billions of embeddings thanks to an [efficient approximate nearest neighbor strategy](https://ai.googleblog.com/2020/07/announcing-scann-efficient-vector.html) for retrieving the document vectors closest to a query vector, based on a [recent paper from Google Research](https://arxiv.org/abs/1908.10396).
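To make the retrieval component concrete, here is a minimal sketch of *exact* nearest-neighbor search in plain Python: every document vector is scored against the query with cosine similarity and the top `k` are kept. The helper names, toy 3-dimensional vectors, and document ids are made up for illustration. Matching Engine replaces this linear scan over all documents with an approximate search so that it scales to very large collections.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two vectors (1.0 means same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def exact_top_k(
    query: list[float], docs: dict[str, list[float]], k: int
) -> list[tuple[str, float]]:
    """Brute-force search: score every document and keep the k most similar."""
    scored = [(doc_id, cosine_similarity(query, vec)) for doc_id, vec in docs.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy "embeddings" (real PaLM embeddings have many more dimensions).
toy_docs = {
    "mask_study": [0.9, 0.1, 0.0],
    "vaccine_trial": [0.7, 0.6, 0.1],
    "dog_rabies": [0.0, 0.2, 0.9],
}
toy_query = [1.0, 0.2, 0.0]

top_two = exact_top_k(toy_query, toy_docs, k=2)
print(top_two)  # ids ranked: mask_study, then vaccine_trial
```

The cost of this scan grows linearly with the number of documents, which is exactly what an approximate index avoids.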



**Dataset:** We will use a very small subset of the [COVID-19 Open Research Dataset Challenge (CORD-19)](https://www.kaggle.com/datasets/allen-institute-for-ai/CORD-19-research-challenge), which contains around 1 million medical research papers focused on COVID-19. For the sake of speed, we only use the title, abstract, and URL of 4000 papers from 2021.

## Setup 

In [8]:
import os
import json

from IPython import display
import pandas as pd
from google.cloud import aiplatform
from vertexai.language_models import TextEmbeddingModel

In [4]:
REGION = "us-central1"
PROJECT = !(gcloud config get-value core/project)
PROJECT = PROJECT[0]
BUCKET = f"{PROJECT}-cord19-matching"

# Do not change these
os.environ["PROJECT"] = PROJECT
os.environ["BUCKET"] = BUCKET
os.environ["REGION"] = REGION

In [5]:
!gsutil ls gs://{BUCKET} || gsutil mb -l {REGION} gs://{BUCKET}

BucketNotFoundException: 404 gs://dherin-dev-cord19-matching bucket does not exist.
Creating gs://dherin-dev-cord19-matching/...


## Loading the data

The dataset we will use consists of the title, abstract, and URL metadata of 4000 samples from the ~1 million medical papers in the [COVID-19 Open Research Dataset Challenge (CORD-19)](https://www.kaggle.com/datasets/allen-institute-for-ai/CORD-19-research-challenge). In this lab, the abstracts serve as the documents for which we compute and store the embeddings.

In [6]:
metadata = pd.read_csv('../data/cord19_metadata_sample.csv.gz')
metadata.head()

|   | title | abstract | url |
|---|-------|----------|-----|
| 0 | Ethnobotanical and ethnomedicinal analysis of ... | Algerian people largely rely on traditional me... | https://www.ncbi.nlm.nih.gov/pubmed/34131369/;... |
| 1 | 1.9 Adolescents in Crisis: Psychological Impac... |  | https://doi.org/10.1016/j.jaac.2021.09.022; ht... |
| 2 | Myopericarditis in a previously healthy adoles... | We report the case of a previously healthy 16‐... | https://www.ncbi.nlm.nih.gov/pubmed/34133825/;... |
| 3 | Religious Support as a Contribution to Face th... | Coping with the COVID-19 pandemic has required... | https://www.ncbi.nlm.nih.gov/pubmed/33405093/;... |
| 4 | The urgency of resuming disrupted dog rabies v... | OBJECTIVE: Dog vaccination is a cost-effective... | http://medrxiv.org/cgi/content/short/2021.04.2... |


## Creating the embeddings

The first step is to create embedding vectors for our 4000 abstracts. To do so, we instantiate the `TextEmbeddingModel` client with the appropriate version of the PaLM model, which for embeddings is `textembedding-gecko`. It is a smaller model than the ones used for text (`text-bison`) and chat (`chat-bison`) generation, which allows for faster processing:

In [9]:
model = TextEmbeddingModel.from_pretrained("textembedding-gecko@001")

The embedding model accepts a list of at most 5 texts per request. We therefore iterate over the abstracts in `metadata.abstract` in batches of 5, pass each batch to `model.get_embeddings`, and collect the embeddings of all 4000 abstracts in the list `vectors`. Running the next cell takes a couple of minutes:

In [10]:
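MAX_BATCH_SIZE = 5  # maximum number of texts per embedding request
vectors = []

# Some papers have no abstract (e.g. row 1 above); fall back to the title so
# that every document has actual text to embed.
abstracts = metadata.abstract.fillna(metadata.title).tolist()

for i in range(0, len(abstracts), MAX_BATCH_SIZE):
    batch = abstracts[i : i + MAX_BATCH_SIZE]
    embeddings = model.get_embeddings(batch)
    vectors.extend([embedding.values for embedding in embeddings])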
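The slicing in the loop above is a generic batching pattern. The sketch below factors it into a small helper, `batched` (our own name, not part of the Vertex AI SDK), which also makes the number of API calls easy to check: 4000 abstracts at 5 texts per request means 800 requests. On Python 3.12 and later, `itertools.batched` offers similar behavior, yielding tuples instead of slices.

```python
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def batched(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of at most `batch_size` items; the last may be shorter."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]


# 4000 documents at 5 texts per request.
num_requests = sum(1 for _ in batched(range(4000), 5))
print(num_requests)  # 800
```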

## Creating the matching engine input file

At this point, our 4000 abstract embeddings are stored in memory in the `vectors` list. To load them into [Vertex Matching Engine](https://cloud.google.com/vertex-ai/docs/matching-engine/overview), we need to serialize them into a JSON Lines file (one JSON object per line) with the [following format](https://cloud.google.com/vertex-ai/docs/matching-engine/match-eng-setup/format-structure):

```python
{"id": <DOCUMENT_ID1>, "embedding": [0.1, ..., -0.7]}
{"id": <DOCUMENT_ID2>, "embedding": [-0.4, ..., 0.8]}
etc.
```
where the value of the `id` field is an identifier that lets us retrieve the actual document from a separate source, and the value of `embedding` is the vector returned by the PaLM API.

For the document `id`, we simply use the row index in the `metadata` DataFrame, which serves as our in-memory document store. This makes it particularly easy to retrieve the abstract, title, and URL from an `id` returned by Matching Engine:

```python
metadata.abstract[id]
metadata.title[id]
metadata.url[id]
```

The next cell iterates over `vectors` and, for each entry, writes a JSON line like the ones above to `cord19_embeddings.json`, containing the index of the abstract in `metadata` and the embedding vector returned by PaLM:

In [12]:
embeddings_file_path = "cord19_embeddings.json"

# Opening the file in "w" mode overwrites any previous version of it
with open(embeddings_file_path, "w") as embeddings_file:
    for i, embedding in enumerate(vectors):
        # One JSON object per line; the id is stored as a string
        json_line = json.dumps({"id": str(i), "embedding": embedding})
        embeddings_file.write(json_line + "\n")

Let's verify that our embedding file has 4000 lines, one per abstract, and then copy it to our Cloud Storage bucket:

In [49]:
!wc -l {embeddings_file_path}

4000 cord19_embeddings.json
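The `wc -l` check only counts lines. A slightly stronger check, sketched below with a hypothetical helper `validate_embeddings_file`, parses every line and confirms that each record has a unique `id` and an `embedding` that is a list of numbers of consistent length. It accepts ids as strings or integers; the examples in the format documentation linked above use strings. The demo runs on a toy two-record file so the cell is self-contained.

```python
import json
import os
import tempfile


def validate_embeddings_file(path: str) -> tuple[int, int]:
    """Hypothetical sanity check for a JSON Lines embeddings file.

    Returns (number_of_records, embedding_dimension) or raises ValueError.
    """
    dimension = None
    seen_ids = set()
    with open(path) as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            record = json.loads(line)
            doc_id = record.get("id")
            if not isinstance(doc_id, (str, int)):
                raise ValueError(f"line {line_number}: missing or invalid 'id'")
            if str(doc_id) in seen_ids:
                raise ValueError(f"line {line_number}: duplicate id {doc_id!r}")
            seen_ids.add(str(doc_id))
            embedding = record.get("embedding")
            if not isinstance(embedding, list) or not all(
                isinstance(x, (int, float)) for x in embedding
            ):
                raise ValueError(f"line {line_number}: 'embedding' must be a list of numbers")
            if dimension is None:
                dimension = len(embedding)
            elif len(embedding) != dimension:
                raise ValueError(
                    f"line {line_number}: expected {dimension} dimensions, got {len(embedding)}"
                )
    if dimension is None:
        raise ValueError("file contains no records")
    return len(seen_ids), dimension


# Toy two-record file in the same format as cord19_embeddings.json.
with tempfile.TemporaryDirectory() as tmp_dir:
    toy_path = os.path.join(tmp_dir, "toy_embeddings.json")
    with open(toy_path, "w") as f:
        for i, vec in enumerate([[0.1, -0.7], [-0.4, 0.8]]):
            f.write(json.dumps({"id": str(i), "embedding": vec}) + "\n")
    toy_result = validate_embeddings_file(toy_path)

print(toy_result)  # (2, 2)
```

In the notebook, calling `validate_embeddings_file(embeddings_file_path)` should report 4000 records.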


In [13]:
EMBEDDINGS_URI = f"gs://{BUCKET}"

!gsutil cp {embeddings_file_path} {EMBEDDINGS_URI}

Copying file://cord19_embeddings.json [Content-Type=application/json]...
- [1 files][ 64.9 MiB/ 64.9 MiB]                                                
Operation completed over 1 objects/64.9 MiB.                                     


## Creating the matching engine index

In [14]:
DISPLAY_NAME = "cord19_palm_embeddings"

matching_engine_index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
    display_name=DISPLAY_NAME,
    contents_delta_uri=EMBEDDINGS_URI,
    dimensions=len(vectors[0]),
    approximate_neighbors_count=150,
    distance_measure_type="COSINE_DISTANCE",
    leaf_node_embedding_count=500,
    leaf_nodes_to_search_percent=7,
    description=DISPLAY_NAME,
)

Creating MatchingEngineIndex
Create MatchingEngineIndex backing LRO: projects/115851500182/locations/us-central1/indexes/6107013036109725696/operations/263368975138684928
MatchingEngineIndex created. Resource name: projects/115851500182/locations/us-central1/indexes/6107013036109725696
To use this MatchingEngineIndex in another session:
index = aiplatform.MatchingEngineIndex('projects/115851500182/locations/us-central1/indexes/6107013036109725696')
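The tree-AH index (a shallow tree plus asymmetric hashing) is what makes Matching Engine approximate. The sketch below is **not** that algorithm: it is a deliberately crude partition-then-scan stand-in (random leaf centers, no hashing or quantization, hypothetical function names) meant only to illustrate the trade-off the two leaf parameters control. `leaf_node_embedding_count` sets roughly how many vectors share a leaf, and `leaf_nodes_to_search_percent` sets what fraction of leaves a query scans; scanning fewer leaves is cheaper but can miss true neighbors. Distances are cosine distances (1 minus cosine similarity), matching `distance_measure_type="COSINE_DISTANCE"`.

```python
import math
import random


def cosine_distance(a: list[float], b: list[float]) -> float:
    """1 - cosine similarity: 0 for identical directions, up to 2 for opposite ones."""
    dot = sum(x * y for x, y in zip(a, b))
    norms = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norms


def build_leaves(vectors: dict[str, list[float]], leaf_size: int, seed: int = 0):
    """Hypothetical crude partitioning: pick ceil(n / leaf_size) stored vectors at
    random as leaf centers and assign every vector to its closest center."""
    rng = random.Random(seed)
    num_leaves = max(1, math.ceil(len(vectors) / leaf_size))
    centers = rng.sample(list(vectors.values()), num_leaves)
    leaves: list[list[str]] = [[] for _ in centers]
    for doc_id, vec in vectors.items():
        closest = min(range(num_leaves), key=lambda j: cosine_distance(vec, centers[j]))
        leaves[closest].append(doc_id)
    return centers, leaves


def approximate_search(query, vectors, centers, leaves, k, leaves_to_search_percent):
    """Scan only the leaves whose centers are closest to the query, then rank
    the candidates they contain by exact cosine distance."""
    num_to_search = max(1, math.ceil(len(centers) * leaves_to_search_percent / 100))
    ranked = sorted(range(len(centers)), key=lambda j: cosine_distance(query, centers[j]))
    candidates = [doc_id for j in ranked[:num_to_search] for doc_id in leaves[j]]
    candidates.sort(key=lambda doc_id: cosine_distance(query, vectors[doc_id]))
    return candidates[:k]


def exact_search(query, vectors, k):
    """Brute-force baseline: rank every stored vector."""
    return sorted(vectors, key=lambda doc_id: cosine_distance(query, vectors[doc_id]))[:k]


# Toy data: 1000 random 8-dimensional vectors with string ids.
rng = random.Random(42)
toy_vectors = {str(i): [rng.gauss(0, 1) for _ in range(8)] for i in range(1000)}
centers, leaves = build_leaves(toy_vectors, leaf_size=50)  # 20 leaves
toy_query = [rng.gauss(0, 1) for _ in range(8)]

exact = exact_search(toy_query, toy_vectors, k=10)
for percent in (10, 50, 100):
    approx = approximate_search(
        toy_query, toy_vectors, centers, leaves, k=10, leaves_to_search_percent=percent
    )
    recall = len(set(approx) & set(exact)) / len(exact)
    print(f"search {percent}% of leaves -> recall@10 = {recall:.1f}")
```

Searching 100% of the leaves degenerates into the exact brute-force search, while smaller percentages trade recall for speed, which is the knob `leaf_nodes_to_search_percent=7` turns in the index above.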


In [16]:
INDEX_RESOURCE_NAME = matching_engine_index.resource_name

print(INDEX_RESOURCE_NAME)

projects/115851500182/locations/us-central1/indexes/6107013036109725696


In [17]:
matching_engine_index = aiplatform.MatchingEngineIndex(index_name=INDEX_RESOURCE_NAME)

In [18]:
matching_engine_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
    display_name=DISPLAY_NAME,
    description=DISPLAY_NAME,
    public_endpoint_enabled=True,
)

Creating MatchingEngineIndexEndpoint
Create MatchingEngineIndexEndpoint backing LRO: projects/115851500182/locations/us-central1/indexEndpoints/3039850583637884928/operations/4325826945259405312
MatchingEngineIndexEndpoint created. Resource name: projects/115851500182/locations/us-central1/indexEndpoints/3039850583637884928
To use this MatchingEngineIndexEndpoint in another session:
index_endpoint = aiplatform.MatchingEngineIndexEndpoint('projects/115851500182/locations/us-central1/indexEndpoints/3039850583637884928')


In [21]:
DEPLOYED_INDEX_ID = "cord19_deployed_index"

In [22]:
matching_engine = matching_engine_endpoint.deploy_index(
    index=matching_engine_index, deployed_index_id=DEPLOYED_INDEX_ID
)

matching_engine.deployed_indexes

Deploying index MatchingEngineIndexEndpoint index_endpoint: projects/115851500182/locations/us-central1/indexEndpoints/3039850583637884928
Deploy index MatchingEngineIndexEndpoint index_endpoint backing LRO: projects/115851500182/locations/us-central1/indexEndpoints/3039850583637884928/operations/7558285577804578816
MatchingEngineIndexEndpoint index_endpoint Deployed index. Resource name: projects/115851500182/locations/us-central1/indexEndpoints/3039850583637884928


[id: "cord19_deployed_index"
index: "projects/115851500182/locations/us-central1/indexes/6107013036109725696"
create_time {
  seconds: 1692234125
  nanos: 911413000
}
index_sync_time {
  seconds: 1692235063
  nanos: 407296000
}
automatic_resources {
  min_replica_count: 2
  max_replica_count: 2
}
deployment_group: "default"
]

## Querying Matching Engine

In [44]:
QUERY = "prophylactic measures"

text_embeddings = [
    vector.values 
    for vector in model.get_embeddings([QUERY])
]

In [45]:
# Define number of neighbors to return
NUM_NEIGHBORS = 10

response = matching_engine.find_neighbors(
    deployed_index_id=DEPLOYED_INDEX_ID,
    queries=text_embeddings,
    num_neighbors=NUM_NEIGHBORS,
)

response

[[MatchNeighbor(id='1264', distance=0.3265237808227539),
  MatchNeighbor(id='2314', distance=0.3303952217102051),
  MatchNeighbor(id='896', distance=0.33482474088668823),
  MatchNeighbor(id='1949', distance=0.3377643823623657),
  MatchNeighbor(id='2035', distance=0.34422242641448975),
  MatchNeighbor(id='2217', distance=0.34488797187805176),
  MatchNeighbor(id='1539', distance=0.3458418846130371),
  MatchNeighbor(id='1568', distance=0.3460204601287842),
  MatchNeighbor(id='2106', distance=0.3461019992828369),
  MatchNeighbor(id='603', distance=0.34620773792266846)]]

In [46]:
matched_ids = [int(match.id) for match in response[0]]
matched_distances = [match.distance for match in response[0]]
matched_titles = [metadata.title[i] for i in matched_ids]
matched_abstracts = [metadata.abstract[i] for i in matched_ids]
matched_urls = [metadata.url[i] for i in matched_ids]

matches = pd.DataFrame({
    "distance": matched_distances,
    "title": matched_titles,
    "abstract": matched_abstracts,
    "url": matched_urls
})
matches

|   | distance | title | abstract | url |
|---|----------|-------|----------|-----|
| 0 | 0.326524 | MEDIDAS DE BIOSSEGURANÇA PARA ENFRENTAMENTO AO... | Introdução A doença coronavírus 2019 (COVID-19... | https://api.elsevier.com/content/article/pii/S... |
| 1 | 0.330395 | Prevention of viral infections in solid organ ... | INTRODUCTION In solid organ transplant (SOT) r... | https://www.ncbi.nlm.nih.gov/pubmed/34854329/;... |
| 2 | 0.334825 | Psychological aspects of pain prevention | How to prevent the onset, maintenance, or exac... | https://www.ncbi.nlm.nih.gov/pubmed/33977186/;... |
| 3 | 0.337764 | Protective measures against COVID-19 and the b... | The present study sought to develop a conceptu... | https://www.sciencedirect.com/science/article/... |
| 4 | 0.344222 | Interplay between risk perception, behaviour, ... | Pharmaceutical and non-pharmaceutical interven... | https://arxiv.org/pdf/2112.12062v2.pdf; https:... |
| 5 | 0.344888 | Change in Pediatric Health Care Spending and D... | Objective: To evaluate how the restrictive mea... | https://www.ncbi.nlm.nih.gov/pubmed/34943379/;... |
| 6 | 0.345842 | Severe acute respiratory syndrome coronavirus-... | Understanding infections related to handling h... | https://www.ncbi.nlm.nih.gov/pubmed/33657925/;... |
| 7 | 0.34602 | The Importance of Understanding COVID-19: The ... | Background: Past research suggests that knowle... | https://www.ncbi.nlm.nih.gov/pubmed/33889557/;... |
| 8 | 0.346102 | Lesson learned from the pandemic: Isolation an... | INTRODUCTION It was previously demonstrated th... | https://doi.org/10.1177/10781552211043836; htt... |
| 9 | 0.346208 | Mosques in Japan responding to COVID-19 pandem... | Religious activities tend to be conducted in e... | https://api.elsevier.com/content/article/pii/S... |


In [47]:
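from html import escape

items = []
for row in matches.itertuples():
    # The url column may hold several "; "-separated links; keep the first one
    url = str(row.url).split("; ")[0]
    items.append(f"""
    <li>
        <article>
            <header>
                <a href="{escape(url)}"><h2>{escape(str(row.title))}</h2></a>
            </header>
            <p>{escape(str(row.abstract))}</p>
        </article>
    </li>""")

# Close the <ol> list before closing the body
page = "<html><body><ol>" + "".join(items) + "</ol></body></html>"
display.HTML(page)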

## Cleaning Up

In [None]:
matching_engine.delete(force=True)
matching_engine_index.delete()

Copyright 2023 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.