<a href="https://colab.research.google.com/github/YoshiyukiKono/semantic-text-search/blob/main/semantic_text_search-en.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Semantic Text Search by Astra DB Vector Search

## Prerequisites

### Astra DB

1. Create a new ***vector search enabled database*** in [Astra](https://astra.datastax.com/).
1. Create a keyspace named `semantics`. (If you use a different name, change the keyspace name in the code block below accordingly.)
1. Get an application token.

We will create a table and an index in this walkthrough.


### Google Colab

The embedding model runs much faster on a GPU, so a GPU runtime is recommended.
To enable it, change the setting from the menu as follows:
Runtime > Change runtime type > Hardware accelerator: `GPU`

## Data Set and Sentence Transformers

First, we prepare the data set used for this demo and the tool that embeds the data into vectors.

To begin we must install the required prerequisite libraries:


In [None]:
!pip install -U \
  datasets==2.12.0 \
  sentence-transformers==2.2.2



### Data Preprocessing
The dataset preparation process requires a few steps:

1. We download the Quora dataset from Hugging Face Datasets.
2. The text content of the dataset is embedded into vectors.



In [None]:
from datasets import load_dataset

dataset = load_dataset('quora', split='train[240000:320000]')
dataset

Dataset({
    features: ['questions', 'is_duplicate'],
    num_rows: 80000
})

The full Quora dataset contains ~400K pairs of natural language questions; the slice loaded above covers 80,000 of them.

In [None]:
dataset[:5]

{'questions': [{'id': [207550, 351729],
   'text': ['What is the truth of life?', "What's the evil truth of life?"]},
  {'id': [33183, 351730],
   'text': ['Which is the best smartphone under 20K in India?',
    'Which is the best smartphone with in 20k in India?']},
  {'id': [351731, 351732],
   'text': ['Steps taken by Canadian government to improve literacy rate?',
    'Can I send homemade herbal hair oil from India to US via postal or private courier services?']},
  {'id': [37799, 94186],
   'text': ['What is a good way to lose 30 pounds in 2 months?',
    'What can I do to lose 30 pounds in 2 months?']},
  {'id': [351733, 351734],
   'text': ['Which of the following most accurately describes the translation of the graph y = (x+3)^2 -2 to the graph of y = (x -2)^2 +2?',
    'How do you graph x + 2y = -2?']}],
 'is_duplicate': [False, True, False, True, False]}

Whether the questions are duplicates is not important here; all we need for this example is the text itself. We can extract all the questions into a single `questions` list.

In [None]:
questions = []

for record in dataset['questions']:
    questions.extend(record['text'])

# remove duplicates
questions = list(set(questions))
print('\n'.join(questions[:5]))
print(len(questions))

Almost overnight a couple of my upper molar teeth on both sides have started to feel pointy and hurt my tongue when I speak. What should I do?
One of my friend submitted fake IT proofs to save tax. If he was found by the IT department, what will be the consequences?
Inbound Marketing: Hubspot: Have any small, specialized high tech consulting companies had success with Hubspot?
What type of government does Greece have? How effective has this government been?
Is "No Man's Sky" considered a failure?
136057
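Because a Python `set` does not preserve insertion order, the five questions printed above can differ between runs. If a reproducible order matters, an order-preserving alternative might look like the sketch below. `dedupe_keep_order` is a hypothetical helper name, not part of the original notebook.

```python
def dedupe_keep_order(items):
    """Remove duplicates while keeping the first occurrence of each item."""
    # dict keys are unique and keep insertion order (Python 3.7+).
    return list(dict.fromkeys(items))


sample = [
    'What is the truth of life?',
    'How do you graph x + 2y = -2?',
    'What is the truth of life?',
]
unique_sample = dedupe_keep_order(sample)
print(unique_sample)
```

Swapping `list(set(questions))` for `dedupe_keep_order(questions)` should give the same number of questions, only in a stable order.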


### Building Embeddings

To create our embeddings we will use the `MiniLM-L6` sentence transformer model (`all-MiniLM-L6-v2`). This is a very efficient semantic similarity embedding model from the sentence-transformers library. We initialize it like so:

In [None]:
from sentence_transformers import SentenceTransformer
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device != 'cuda':
    print(f"You are using {device}. This is much slower than using "
          "a CUDA-enabled GPU. If on Colab you can change this by "
          "clicking Runtime > Change runtime type > GPU.")

model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

The model printout above contains three useful pieces of information:

 - `max_seq_length` is `256`. This is the maximum number of tokens (roughly, words or word pieces) that can be encoded into a single vector embedding. Anything beyond this limit is truncated.

 - `word_embedding_dimension` is `384`. This is the dimensionality of the vectors output by this model. We will need this number later when registering the data set in our Astra DB vector-enabled database.

 - `Normalize()`. This final normalization step indicates that all vectors produced by the model are normalized to unit length. As a result, the similarity we would typically measure with cosine similarity can also be measured with the dot product metric. In fact, with normalized vectors cosine and dot product similarity are equivalent.
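The equivalence of cosine and dot product similarity for normalized vectors can be checked with a few lines of plain Python. This is a minimal sketch on made-up 3-dimensional vectors; the helper names `dot`, `normalize`, and `cosine` are illustrative and not part of the notebook.

```python
import math


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]


def cosine(a, b):
    """Cosine similarity: dot product divided by the product of the norms."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


a = normalize([1.0, 2.0, 3.0])
b = normalize([-2.0, 0.5, 4.0])

# For unit-length vectors both norms are 1, so the two measures coincide.
print(math.isclose(cosine(a, b), dot(a, b)))
```

This is why the index created later in this walkthrough can use `dot_product` as its similarity function without changing the ranking.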

Moving on, we can create a sentence embedding using this model like so:

In [None]:
query = 'which city is the most populated in the world?'

xq = model.encode(query)
xq.shape

(384,)

`xq.shape` shows the dimensionality of the resulting vector: the query string has been converted to a 384-dimensional vector. A query string of a different length, as shown below, is converted to a vector with the same number of dimensions.

In [None]:
query = 'Is it true that the coordinaate of a point on x-axis can be taken as (y,0) while on y-axis it can be taken as (0,x)?'

xq = model.encode(query)
xq.shape

(384,)

We will use this model to embed all questions when upserting them to Astra DB.

In [None]:
def get_embeddings(text):
  return model.encode(text).tolist()

## Astra DB Connection

### Cassandra Driver Install

In [None]:
!pip install cassandra-driver

Collecting cassandra-driver
  Downloading cassandra_driver-3.28.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (19.1 MB)
Collecting geomet<0.3,>=0.1 (from cassandra-driver)
  Downloading geomet-0.2.1.post1-py3-none-any.whl (18 kB)
Installing collected packages: geomet, cassandra-driver
Successfully installed cassandra-driver-3.28.0 geomet-0.2.1.post1


In [None]:
import cassandra

print(cassandra.__version__)

3.28.0


### Astra DB Security Settings

Place the Connect Bundle file in the execution environment: download it from the Astra control plane, then upload it through the "Files" menu on the left side of Colab.

Alternatively, you can download the Connect Bundle file directly from Astra into your Colab environment (**please modify the cell below**). Note that the download URL shown in your Astra environment is not static, so you need to copy it again whenever you run this demo in a new Colab session.

In [None]:
!wget -O secure-connect-demo.zip "https://datastax..."

--2023-07-27 02:19:31--  https://datastax-cluster-config-prod.s3.us-east-2.amazonaws.com/d5556151-ea9a-4309-8be3-b8ea2b1cd03d-1/secure-connect-demo.zip?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIA2AIQRQ76S2JCB77W%2F20230727%2Fus-east-2%2Fs3%2Faws4_request&X-Amz-Date=20230727T021921Z&X-Amz-Expires=300&X-Amz-SignedHeaders=host&X-Amz-Signature=e064e579dac30f9be1a74dcfea2e8ac701a758d4baa3ff5b5eeb6b02afe86db5
Resolving datastax-cluster-config-prod.s3.us-east-2.amazonaws.com (datastax-cluster-config-prod.s3.us-east-2.amazonaws.com)... 52.219.80.64, 52.219.177.66, 52.219.178.226, ...
Connecting to datastax-cluster-config-prod.s3.us-east-2.amazonaws.com (datastax-cluster-config-prod.s3.us-east-2.amazonaws.com)|52.219.80.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 12247 (12K) [application/zip]
Saving to: ‘secure-connect-demo.zip’


2023-07-27 02:19:32 (157 MB/s) - ‘secure-connect-demo.zip’ saved [12247/12247]



Modify the following variables to access your environment.

In [None]:
SECURE_CONNECT_BUNDLE_PATH = 'secure-connect-demo.zip'
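Before connecting, it can be worth confirming that the bundle really is where the path says it is. The following is a minimal sketch using only the standard library. `looks_like_bundle` is a hypothetical helper, and it only checks that the file exists and is a zip archive, not that its contents are valid.

```python
import os
import zipfile


def looks_like_bundle(path):
    """Return True if `path` is an existing file that is a zip archive."""
    return os.path.isfile(path) and zipfile.is_zipfile(path)


bundle_path = 'secure-connect-demo.zip'  # same value as SECURE_CONNECT_BUNDLE_PATH
if not looks_like_bundle(bundle_path):
    print(f'{bundle_path} is missing or is not a zip file; upload or download it again.')
```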

In [None]:
import getpass

ASTRA_CLIENT_ID = getpass.getpass()

··········


In [None]:
ASTRA_CLIENT_SECRET = getpass.getpass()

··········


You don't need to run the block below, but it can help you check the entered values if the subsequent connection to the database fails. Note that it prints the secret in plain text.

In [None]:
print('ASTRA_CLIENT_ID:[' + ASTRA_CLIENT_ID + ']')
print('ASTRA_CLIENT_SECRET:[' + ASTRA_CLIENT_SECRET + ']')
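If you would rather not leave the credentials in the notebook output, a masked check may be enough to spot an empty value or stray whitespace. This is a sketch only. `mask_secret` is a hypothetical helper, and the example value is made up.

```python
def mask_secret(value, visible=4):
    """Show only the first `visible` characters; replace the rest with '*'."""
    if len(value) <= visible:
        return '*' * len(value)
    return value[:visible] + '*' * (len(value) - visible)


example_secret = 'abcd1234efgh'  # made-up value for illustration only
print('masked:[' + mask_secret(example_secret) + ']', 'length:', len(example_secret))
```

In the notebook you would pass `ASTRA_CLIENT_ID` or `ASTRA_CLIENT_SECRET` in place of `example_secret`.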

### Connection

In [None]:
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider

cloud_config= {
  'secure_connect_bundle': SECURE_CONNECT_BUNDLE_PATH
}
auth_provider = PlainTextAuthProvider(ASTRA_CLIENT_ID, ASTRA_CLIENT_SECRET)
cluster = Cluster(cloud=cloud_config, auth_provider=auth_provider)
session = cluster.connect()

row = session.execute("select release_version from system.local").one()
if row:
  print(row[0])
else:
  print("An error occurred.")

ERROR:cassandra.connection:Closing connection <AsyncoreConnection(132838110202656) d5556151-ea9a-4309-8be3-b8ea2b1cd03d-us-east1.db.astra.datastax.com:29042:5903f2b2-4bfe-4035-9c52-a19adae6e381> due to protocol error: Error from server: code=000a [Protocol error] message="Beta version of the protocol used (5/v5-beta), but USE_BETA flag is unset"


4.0.7-a81def0a9e90


Define the keyspace name and check that the keyspace exists.

In [None]:
YOUR_KEYSPACE = 'semantics'

In [None]:
session.set_keyspace(YOUR_KEYSPACE)
session

<cassandra.cluster.Session at 0x78d0c00ff6a0>

## Vector Search powered by Astra DB

### Environment Preparation

We will create a table and an index for the demo.

In [None]:
session.execute(f"""CREATE TABLE IF NOT EXISTS {YOUR_KEYSPACE}.questions
(id uuid,
 question text,
 question_embedding vector<float, 384>,

 PRIMARY KEY (id))""")

<cassandra.cluster.ResultSet at 0x78d1c86b9ea0>

In [None]:
session.execute(f"""CREATE CUSTOM INDEX IF NOT EXISTS vector_search_index
   ON {YOUR_KEYSPACE}.questions (question_embedding)
   USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'
   WITH OPTIONS = {{ 'similarity_function': 'dot_product' }}""")

<cassandra.cluster.ResultSet at 0x78d0ca12b220>

Before registering the demo data set, let's just check the created table and index using a sample record.

In [None]:
question = 'Is it true that the coordinaate of a point on x-axis can be taken as (y,0) while on y-axis it can be taken as (0,x)?'
embedding = get_embeddings(question)
embedding

[0.05344191938638687,
 -0.030281946063041687,
 -0.06466147303581238,
 -0.12616895139217377,
 -0.02430957742035389,
 0.030188502743840218,
 -0.03165227174758911,
 -0.044692106544971466,
 0.0951315313577652,
 -0.005873900838196278,
 0.14823749661445618,
 0.10260108858346939,
 0.003760535968467593,
 0.0680219978094101,
 0.03177761286497116,
 -0.05901845172047615,
 -0.03500355780124664,
 -0.07596974074840546,
 0.038683172315359116,
 0.056720227003097534,
 0.07208768278360367,
 -0.03127380833029747,
 -0.012993257492780685,
 0.03162640705704689,
 0.07051978260278702,
 -0.048730675131082535,
 0.09373128414154053,
 0.026617592200636864,
 0.017795585095882416,
 0.02070174552500248,
 -0.04620673134922981,
 -0.03261915594339371,
 -0.06299709528684616,
 0.0015578059246763587,
 -0.004674585070461035,
 -0.006311411038041115,
 0.1345195472240448,
 -0.03287360444664955,
 0.029681209474802017,
 0.06412471830844879,
 0.03275110572576523,
 0.024843985214829445,
 0.055040955543518066,
 0.00878546107560396

Use this record to verify that the table and index defined above work correctly. (Later, we will see the power of vector search after registering a large amount of data in the database.)

In [None]:
from cassandra.query import SimpleStatement
query = SimpleStatement(
                f"""
                INSERT INTO {YOUR_KEYSPACE}.questions
                (id, question, question_embedding)
                VALUES (now(), %s, %s)
                """
            )
session.execute(query,(question, embedding))

<cassandra.cluster.ResultSet at 0x78d0c00fdc60>

In [None]:
query = SimpleStatement(
    f"""
    SELECT id, question, question_embedding
    FROM {YOUR_KEYSPACE}.questions
    ORDER BY question_embedding ANN OF {embedding} LIMIT 5;
    """
    )
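The statement above relies on Python rendering the list as `[0.05, -0.03, ...]`, which has the same shape as a CQL vector literal. The sketch below makes that step explicit and checks the vector length first. `build_ann_query` and `EXPECTED_DIM` are hypothetical names introduced for illustration, and the table and column names are the ones used in this walkthrough.

```python
EXPECTED_DIM = 384  # dimension declared in the table definition


def build_ann_query(keyspace, embedding, limit=5, expected_dim=EXPECTED_DIM):
    """Build the text of the ANN SELECT statement for a query embedding."""
    if len(embedding) != expected_dim:
        raise ValueError(f'expected {expected_dim} values, got {len(embedding)}')
    # Convert each value to a plain float so the literal contains only numbers.
    literal = '[' + ', '.join(repr(float(x)) for x in embedding) + ']'
    return (
        f'SELECT id, question, question_embedding '
        f'FROM {keyspace}.questions '
        f'ORDER BY question_embedding ANN OF {literal} LIMIT {int(limit)}'
    )


# A 3-dimensional toy vector, so the expected dimension is overridden here.
cql = build_ann_query('semantics', [0.5, -0.25, 0.0], expected_dim=3)
print(cql)
```

Checking the length up front turns a dimension mismatch into an immediate Python error instead of a failed request.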

In [None]:
results = session.execute(query)
# Iterate over the ResultSet directly instead of using the private `_current_rows` attribute.
top_5_questions = list(results)

for row in top_5_questions:
  print(f"{row.id}, {row.question}, {row.question_embedding}\n")

977df310-2c25-11ee-9e0e-0d8e1043dee0, Is it true that the coordinaate of a point on x-axis can be taken as (y,0) while on y-axis it can be taken as (0,x)?, [0.05344191938638687, -0.030281946063041687, -0.06466147303581238, -0.12616895139217377, -0.02430957742035389, 0.030188502743840218, -0.03165227174758911, -0.044692106544971466, 0.0951315313577652, -0.005873900838196278, 0.14823749661445618, 0.10260108858346939, 0.003760535968467593, 0.0680219978094101, 0.03177761286497116, -0.05901845172047615, -0.03500355780124664, -0.07596974074840546, 0.038683172315359116, 0.056720227003097534, 0.07208768278360367, -0.03127380833029747, -0.012993257492780685, 0.03162640705704689, 0.07051978260278702, -0.048730675131082535, 0.09373128414154053, 0.026617592200636864, 0.017795585095882416, 0.02070174552500248, -0.04620673134922981, -0.03261915594339371, -0.06299709528684616, 0.0015578059246763587, -0.004674585070461035, -0.006311411038041115, 0.1345195472240448, -0.03287360444664955, 0.029681209474

### Data Registration

**PLEASE NOTE:** Make sure you are using a GPU. On a CPU, running the following cell can take a couple of hours. If you want to shorten the run, slice the `questions` list (for example `questions[:N]`), but keep in mind that a reasonably large amount of data (similar to the Pinecone sample) is what shows the power of Astra DB vector search.

In [None]:
from tqdm.auto import tqdm

batch_size = 128
embedding_params = []
for i in tqdm(range(0, len(questions), batch_size)):
    i_end = min(i+batch_size, len(questions))
    embedding_params.extend(model.encode(questions[i:i_end]))

  0%|          | 0/1063 [00:00<?, ?it/s]
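The batch count shown by the progress bar follows directly from the list length and the batch size. The sketch below reproduces that arithmetic and shows how a smaller subset reduces the work. `batch_ranges` and the subset size `N` are hypothetical and chosen only for illustration.

```python
import math


def batch_ranges(total, batch_size):
    """Yield (start, end) index pairs covering `total` items in order."""
    for start in range(0, total, batch_size):
        yield start, min(start + batch_size, total)


total_questions = 136057  # len(questions) reported earlier
num_batches = math.ceil(total_questions / 128)
print(num_batches)  # 1063, matching the progress bar total

# Embedding only the first N questions shrinks the number of batches accordingly.
N = 10000  # hypothetical subset size
subset_batches = list(batch_ranges(N, 128))
print(len(subset_batches), subset_batches[-1])  # 79 (9984, 10000)
```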

In [None]:
print(len(embedding_params))
print(len(questions))

136057
136057


In [None]:
params_list = []
for i in tqdm(range(0, len(questions))):
  params_list.append((questions[i], embedding_params[i]))

  0%|          | 0/136057 [00:00<?, ?it/s]

In [None]:
print(len(params_list))

136057


The block below takes about 5 minutes to run. To shorten it, reduce the number of `questions` (and thus the size of `params_list`) in the cells above.

In [None]:
from cassandra.concurrent import execute_concurrent_with_args
request = session.prepare(
                    f"""
                INSERT INTO {YOUR_KEYSPACE}.questions
                (id, question, question_embedding)
                VALUES (now(), ?, ?)
                """
)
execute_concurrent_with_args(session, request, params_list)

[ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x78d09a3622f0>),
 ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x78d09a3631f0>),
 ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x78d09a363d60>),
 ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x78d09a362500>),
 ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x78d09a363b80>),
 ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x78d09a384310>),
 ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x78d09a362890>),
 ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x78d09a3639a0>),
 ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x78d09a384130>),
 ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object a
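Rather than reading through a long list of `ExecutionResult` entries, the results can be summarized. Each entry is a `(success, result_or_exc)` pair, so a small helper can count successes and collect exceptions. The sketch below uses stand-in tuples instead of a live session, and `summarize_results` is a hypothetical helper. With the driver's default `raise_on_first_error=True`, the first failure raises instead of being returned, so collecting failures this way assumes you pass `raise_on_first_error=False`.

```python
def summarize_results(results):
    """Count successes and collect exceptions from (success, result_or_exc) pairs."""
    failures = [result_or_exc for success, result_or_exc in results if not success]
    return len(results) - len(failures), failures


# Stand-in data shaped like the driver's ExecutionResult tuples (illustration only).
fake_results = [(True, None), (False, RuntimeError('write timeout')), (True, None)]
ok_count, failures = summarize_results(fake_results)
print(f'{ok_count} succeeded, {len(failures)} failed')
```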

Finally, you should see the 136057 registered questions plus the sample record inserted earlier, 136058 rows in total, as follows. (Results may vary depending on previously registered data.)
```
token@cqlsh:demo> select count(*) FROM questions;

 count
--------
 136058

(1 rows)
```

### Vector Search Demo

In [None]:
question = 'How do I promote my e-commerce website?'
embedding = get_embeddings(question)
embedding

[0.006989829242229462,
 -0.04213974252343178,
 -0.07815290242433548,
 -0.009970204904675484,
 0.07739722728729248,
 0.04193941876292229,
 -0.006748862564563751,
 0.030799392610788345,
 -0.07842331379652023,
 -0.03257639706134796,
 0.05655832588672638,
 -0.03520314022898674,
 0.07220868766307831,
 0.028285248205065727,
 0.07695521414279938,
 -0.05429329350590706,
 0.009619899094104767,
 0.05280487611889839,
 -0.01875201240181923,
 -0.11907956749200821,
 0.014776456169784069,
 -0.00023785493976902217,
 0.06035557761788368,
 0.0033046926837414503,
 -0.06253878027200699,
 -0.04962966963648796,
 0.007044524420052767,
 0.05366762354969978,
 -0.008639861829578876,
 -0.11296573281288147,
 0.04041944444179535,
 -0.07675006240606308,
 0.08271502703428268,
 0.04834199696779251,
 -0.02313019335269928,
 0.013846631161868572,
 -0.03671257942914963,
 -0.10863407701253891,
 -0.04204464331269264,
 0.025294026359915733,
 -0.009148983284831047,
 -0.09623943269252777,
 -0.048046406358480453,
 0.0587533973

In [None]:
query = SimpleStatement(
    f"""
    SELECT id, question, question_embedding
    FROM {YOUR_KEYSPACE}.questions
    ORDER BY question_embedding ANN OF {embedding} LIMIT 5;
    """
    )

In [None]:
results = session.execute(query)
# Iterate over the ResultSet directly instead of using the private `_current_rows` attribute.
top_5_questions = list(results)

for row in top_5_questions:
  print(f"{row.id}, {row.question}, {row.question_embedding}\n")

8b3638f0-2c26-11ee-bc57-b9b25387eb5b, How do I promote my e-commerce website?, [0.006989842280745506, -0.04213970899581909, -0.07815290987491608, -0.009970256127417088, 0.0773971900343895, 0.04193941131234169, -0.006748853251338005, 0.030799392610788345, -0.07842329144477844, -0.03257638216018677, 0.056558333337306976, -0.035203125327825546, 0.07220874726772308, 0.02828521840274334, 0.07695518434047699, -0.05429328605532646, 0.009619822725653648, 0.05280487239360809, -0.01875201053917408, -0.11907956749200821, 0.014776422642171383, -0.00023787017562426627, 0.06035558879375458, 0.0033046663738787174, -0.0625387504696846, -0.04962966963648796, 0.007044512778520584, 0.05366763845086098, -0.00863985251635313, -0.11296575516462326, 0.040419455617666245, -0.07675006985664368, 0.08271501958370209, 0.04834200441837311, -0.02313019521534443, 0.01384664699435234, -0.036712534725666046, -0.10863406211137772, -0.04204464331269264, 0.025294041261076927, -0.009149019606411457, -0.09623944759368896, 
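To build intuition for what `ORDER BY ... ANN OF ... LIMIT 5` returns, it can help to look at an exact brute-force version of the same ranking on a tiny in-memory set. This sketch is not how Astra DB performs the search (the index finds approximate nearest neighbors); the 2-dimensional vectors and the helper `top_k` are made up for illustration.

```python
def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def top_k(query_vec, rows, k=5):
    """Return the k (text, vector) rows with the highest dot product to query_vec."""
    return sorted(rows, key=lambda row: dot(query_vec, row[1]), reverse=True)[:k]


# Tiny made-up unit vectors standing in for the 384-dimensional embeddings.
rows = [
    ('question A', [1.0, 0.0]),
    ('question B', [0.6, 0.8]),
    ('question C', [0.0, 1.0]),
]
best = top_k([0.8, 0.6], rows, k=2)
print([text for text, _ in best])  # ['question B', 'question A']
```

A brute-force scan like this compares the query against every row, which is the cost a vector index is meant to avoid on large tables.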

## Cleanup

In [None]:
session.execute(f"""DROP INDEX IF EXISTS {YOUR_KEYSPACE}.vector_search_index""")

<cassandra.cluster.ResultSet at 0x78d1b41dd5a0>

In [None]:
session.execute(f"""DROP TABLE IF EXISTS {YOUR_KEYSPACE}.questions""")

<cassandra.cluster.ResultSet at 0x78d0c00dcee0>