# Text Semantic Search with AutoMM

## 1. Introduction to semantic embedding

Semantic embedding is one of the main workhorses behind modern search technology. Instead of directly matching the query to candidates by term frequency (e.g., BM25), a semantic search algorithm first converts each text $x$ into a feature vector $\phi(x)$ and then compares these vectors using a distance metric defined on that vector space. These feature vectors, known as "vector embeddings", are often trained end-to-end on large text corpora so that they encode the *semantic* meaning of the text. For example, synonyms are embedded in nearby regions of the vector space, and relationships between words are often revealed by algebraic operations (see Figure 1 for an example). For these reasons, a vector embedding of text is also known as a **semantic embedding**. Given semantic embeddings of the query and the candidate documents, search can often be reduced to finding the most similar vectors. This approach to search is known as **semantic search**.

![Similar sentences have similar embeddings. Image from [Medium](https://medium.com/towards-data-science/fine-grained-analysis-of-sentence-embeddings-a3ff0a42cce5)](https://miro.medium.com/max/1400/0*esMqhzu9WhLiP3bD.jpg)
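
As a rough illustration of the idea above, the following sketch ranks candidates by the cosine similarity between their embeddings and the query embedding. The three-dimensional vectors are made-up values chosen for readability, not outputs of any real model, and `cosine_similarity` is a helper written for this example.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two vectors of equal length."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

# Hypothetical embeddings phi(x); real models use hundreds of dimensions.
candidates = {
    "statins and heart disease": [0.9, 0.1, 0.0],
    "cholesterol-lowering drugs": [0.8, 0.3, 0.1],
    "how to bake bread": [0.0, 0.2, 0.9],
}
query_embedding = [1.0, 0.2, 0.0]

# Semantic search reduces to sorting candidates by similarity to the query.
ranked = sorted(
    candidates,
    key=lambda text: cosine_similarity(query_embedding, candidates[text]),
    reverse=True,
)
print(ranked)
```

The unrelated candidate lands last because its embedding points in a very different direction from the query, even though no word matching is involved.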


There are three main advantages of using semantic embeddings for a search problem over classical information-retrieval methods (e.g., bag-of-words or TF-IDF). First, semantic search returns candidates that are related by the meaning of the text rather than by similar word usage, which helps to discover paraphrased text and similar concepts described in very different ways. Second, semantic search is often more computationally efficient: vector embeddings of the candidates can be pre-computed and stored in index data structures, and highly scalable techniques such as locality-sensitive hashing (LSH) and maximum inner product search (MIPS) are available for efficiently finding similar vectors in the embedding space. Last but not least, the semantic embedding approach generalizes the same search algorithm beyond text in a straightforward way, for example to multi-modal search. Can we use a text query to search for images without textual annotations? Can we search for a website using an image query? With semantic search, one can simply use the most appropriate vector embedding of these multi-modal objects and jointly train the embeddings using datasets with both text and images.
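
To make the efficiency argument more concrete, here is a minimal sketch of one LSH scheme for cosine similarity, random-hyperplane hashing. Candidate embeddings are hashed once into buckets, and a query only inspects the bucket that shares its signature. All function names and the toy vectors are assumptions made for this example; production systems typically rely on a dedicated approximate nearest-neighbor library.

```python
import random

def make_hyperplanes(num_bits, dim, seed=0):
    """Draw num_bits random hyperplanes (normal vectors) in dim dimensions."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(num_bits)]

def lsh_signature(vector, hyperplanes):
    """One bit per hyperplane: 1 if the vector lies on its non-negative side."""
    return tuple(
        int(sum(h * x for h, x in zip(plane, vector)) >= 0) for plane in hyperplanes
    )

def build_index(embeddings, hyperplanes):
    """Pre-compute signatures and group document ids by signature."""
    index = {}
    for doc_id, vec in embeddings.items():
        index.setdefault(lsh_signature(vec, hyperplanes), []).append(doc_id)
    return index

planes = make_hyperplanes(num_bits=8, dim=3)
doc_embeddings = {"doc_a": [0.9, 0.1, 0.0], "doc_b": [0.0, 0.2, 0.9]}
index = build_index(doc_embeddings, planes)

# A query pointing in the same direction as doc_a hashes to the same bucket,
# because the signature depends only on direction, not on vector length.
query = [1.8, 0.2, 0.0]
bucket = index.get(lsh_signature(query, planes), [])
print(bucket)
```

Vectors separated by a small angle agree on most bits, so in practice several hash tables are combined to trade recall against lookup cost.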

This tutorial provides a gentle entry point to using AutoMM for semantic search.

In [None]:
!pip install autogluon.multimodal

In [None]:
%%capture
!pip3 install ir_datasets
import ir_datasets
import pandas as pd
pd.set_option('display.max_colwidth', None)

In [None]:
%%capture
dataset = ir_datasets.load("beir/nfcorpus/test")

# prepare dataset
doc_data = pd.DataFrame(dataset.docs_iter())
query_data = pd.DataFrame(dataset.queries_iter())
labeled_data = pd.DataFrame(dataset.qrels_iter())
label_col = "relevance"
query_id_col = "query_id"
doc_id_col = "doc_id"
text_col = "text"
id_mappings = {
    query_id_col: query_data.set_index(query_id_col)[text_col],
    doc_id_col: doc_data.set_index(doc_id_col)[text_col],
}

In [None]:
labeled_data.head()

| | query_id | doc_id | relevance | iteration |
|---|---|---|---|---|
| 0 | PLAIN-2 | MED-2427 | 2 | 0 |
| 1 | PLAIN-2 | MED-10 | 2 | 0 |
| 2 | PLAIN-2 | MED-2429 | 2 | 0 |
| 3 | PLAIN-2 | MED-2430 | 2 | 0 |
| 4 | PLAIN-2 | MED-2431 | 2 | 0 |


In [None]:
query_data.head()

| | query_id | text | url |
|---|---|---|---|
| 0 | PLAIN-2 | Do Cholesterol Statin Drugs Cause Breast Cancer? | http://nutritionfacts.org/2015/07/16/do-cholesterol-statin-drugs-cause-breast-cancer/ |
| 1 | PLAIN-12 | Exploiting Autophagy to Live Longer | http://nutritionfacts.org/2015/06/11/exploiting-autophagy-to-live-longer/ |
| 2 | PLAIN-23 | How to Reduce Exposure to Alkylphenols Through Your Diet | http://nutritionfacts.org/2015/04/28/how-to-reduce-exposure-to-alkylphenols-through-your-diet/ |
| 3 | PLAIN-33 | What’s Driving America’s Obesity Problem? | http://nutritionfacts.org/2015/03/24/whats-driving-americas-obesity-problem/ |
| 4 | PLAIN-44 | Who Should be Careful About Curcumin? | http://nutritionfacts.org/2015/02/12/who-should-be-careful-about-curcumin/ |


In [None]:
query_data = query_data.drop("url", axis=1)
query_data.head()

| | query_id | text |
|---|---|---|
| 0 | PLAIN-2 | Do Cholesterol Statin Drugs Cause Breast Cancer? |
| 1 | PLAIN-12 | Exploiting Autophagy to Live Longer |
| 2 | PLAIN-23 | How to Reduce Exposure to Alkylphenols Through Your Diet |
| 3 | PLAIN-33 | What’s Driving America’s Obesity Problem? |
| 4 | PLAIN-44 | Who Should be Careful About Curcumin? |


In [None]:
doc_data.head(1)

Unnamed: 0,doc_id,text,title,url
0,MED-10,"Recent studies have suggested that statins, an established drug group in the prevention of cardiovascular mortality, could delay or prevent breast cancer recurrence but the effect on disease-specific mortality remains unclear. We evaluated risk of breast cancer death among statin users in a population-based cohort of breast cancer patients. The study cohort included all newly diagnosed breast cancer patients in Finland during 1995–2003 (31,236 cases), identified from the Finnish Cancer Registry. Information on statin use before and after the diagnosis was obtained from a national prescription database. We used the Cox proportional hazards regression method to estimate mortality among statin users with statin use as time-dependent variable. A total of 4,151 participants had used statins. During the median follow-up of 3.25 years after the diagnosis (range 0.08–9.0 years) 6,011 participants died, of which 3,619 (60.2%) was due to breast cancer. After adjustment for age, tumor characteristics, and treatment selection, both post-diagnostic and pre-diagnostic statin use were associated with lowered risk of breast cancer death (HR 0.46, 95% CI 0.38–0.55 and HR 0.54, 95% CI 0.44–0.67, respectively). The risk decrease by post-diagnostic statin use was likely affected by healthy adherer bias; that is, the greater likelihood of dying cancer patients to discontinue statin use as the association was not clearly dose-dependent and observed already at low-dose/short-term use. The dose- and time-dependence of the survival benefit among pre-diagnostic statin users suggests a possible causal effect that should be evaluated further in a clinical trial testing statins’ effect on survival in breast cancer patients.",Statin Use and Breast Cancer Survival: A Nationwide Cohort Study from Finland,http://www.ncbi.nlm.nih.gov/pubmed/25329299


In [None]:
doc_data[text_col] = doc_data[[text_col, "title"]].apply(" ".join, axis=1)
doc_data = doc_data.drop(["title", "url"], axis=1)
doc_data.head(1)

Unnamed: 0,doc_id,text
0,MED-10,"Recent studies have suggested that statins, an established drug group in the prevention of cardiovascular mortality, could delay or prevent breast cancer recurrence but the effect on disease-specific mortality remains unclear. We evaluated risk of breast cancer death among statin users in a population-based cohort of breast cancer patients. The study cohort included all newly diagnosed breast cancer patients in Finland during 1995–2003 (31,236 cases), identified from the Finnish Cancer Registry. Information on statin use before and after the diagnosis was obtained from a national prescription database. We used the Cox proportional hazards regression method to estimate mortality among statin users with statin use as time-dependent variable. A total of 4,151 participants had used statins. During the median follow-up of 3.25 years after the diagnosis (range 0.08–9.0 years) 6,011 participants died, of which 3,619 (60.2%) was due to breast cancer. After adjustment for age, tumor characteristics, and treatment selection, both post-diagnostic and pre-diagnostic statin use were associated with lowered risk of breast cancer death (HR 0.46, 95% CI 0.38–0.55 and HR 0.54, 95% CI 0.44–0.67, respectively). The risk decrease by post-diagnostic statin use was likely affected by healthy adherer bias; that is, the greater likelihood of dying cancer patients to discontinue statin use as the association was not clearly dose-dependent and observed already at low-dose/short-term use. The dose- and time-dependence of the survival benefit among pre-diagnostic statin users suggests a possible causal effect that should be evaluated further in a clinical trial testing statins’ effect on survival in breast cancer patients. Statin Use and Breast Cancer Survival: A Nationwide Cohort Study from Finland"


In [None]:
!pip install scikit-learn rank_bm25




In [None]:
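from sklearn.metrics import ndcg_score
from rank_bm25 import BM25Okapi  # BM25Okapi is provided by the rank_bm25 package
import numpy as np

# tokenize_corpus and rank_documents_bm25 are helper functions that must be
# defined before this cell runs.
def evaluate_bm25(doc_data, query_data, qrel_dict, cutoffs):
    """Rank documents with BM25 and score the ranking with NDCG@max(cutoffs)."""
    tokenized_corpus = tokenize_corpus(doc_data[text_col].tolist())
    bm25_model = BM25Okapi(tokenized_corpus, k1=1.2, b=0.75)

    # results maps query_id -> {doc_id: BM25 score} for the top max(cutoffs) documents
    results = rank_documents_bm25(
        query_data[text_col].tolist(),
        query_data[query_id_col].tolist(),
        doc_data[doc_id_col].tolist(),
        max(cutoffs),
        bm25_model,
    )

    # Prepare relevance and prediction arrays, one row per query
    y_true = []
    y_score = []

    for qid in results:
        true_relevance = []
        predicted_scores = []
        for doc_id, score in results[qid].items():
            # Default to 0 if the query or the document has no relevance label
            true_relevance.append(qrel_dict.get(qid, {}).get(doc_id, 0))
            predicted_scores.append(score)

        y_true.append(true_relevance)
        y_score.append(predicted_scores)

    # Compute NDCG. Note: ndcg_score builds the ideal ranking from the retrieved
    # documents only, so this value is not directly comparable with an NDCG that
    # is computed against all labeled documents.
    ndcg_scores = ndcg_score(np.array(y_true), np.array(y_score), k=max(cutoffs))

    return ndcg_scores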
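
`evaluate_bm25` relies on the helpers `tokenize_corpus` and `rank_documents_bm25`, and a later cell calls `get_qrels`; none of their definitions appear in the cells above. The following is a minimal sketch of what they might look like, inferred only from how they are called. The stop-word list, the regex tokenizer, and the `TinyBM25` class (a simplified stand-in for `BM25Okapi` so the sketch runs without extra packages) are all assumptions of this example, not the tutorial's actual implementation.

```python
import math
import re
from collections import Counter

# Assumption: a tiny stop-word list; a real helper may use a fuller one.
STOP_WORDS = {"a", "an", "and", "are", "do", "for", "in", "is", "of", "the", "to"}

def tokenize_corpus(corpus):
    """Lowercase each text, split on non-alphanumerics, drop stop words."""
    return [
        [tok for tok in re.findall(r"[a-z0-9]+", doc.lower()) if tok not in STOP_WORDS]
        for doc in corpus
    ]

class TinyBM25:
    """Hypothetical stand-in exposing get_scores(), like rank_bm25.BM25Okapi."""

    def __init__(self, tokenized_corpus, k1=1.2, b=0.75):
        self.k1, self.b = k1, b
        self.doc_freqs = [Counter(doc) for doc in tokenized_corpus]
        self.doc_lens = [len(doc) for doc in tokenized_corpus]
        self.avgdl = sum(self.doc_lens) / len(self.doc_lens)
        n_docs = len(tokenized_corpus)
        df = Counter(term for doc in self.doc_freqs for term in doc)
        # Smoothed IDF that stays positive (one common BM25 variant).
        self.idf = {t: math.log(1 + (n_docs - n + 0.5) / (n + 0.5)) for t, n in df.items()}

    def get_scores(self, query_tokens):
        scores = []
        for freqs, doc_len in zip(self.doc_freqs, self.doc_lens):
            norm = self.k1 * (1 - self.b + self.b * doc_len / self.avgdl)
            scores.append(sum(
                self.idf.get(t, 0.0) * freqs[t] * (self.k1 + 1) / (freqs[t] + norm)
                for t in query_tokens
            ))
        return scores

def rank_documents_bm25(queries_text, queries_id, docs_id, top_k, bm25):
    """Return {query_id: {doc_id: score}} with the top_k documents per query."""
    results = {}
    for qid, tokens in zip(queries_id, tokenize_corpus(queries_text)):
        scores = bm25.get_scores(tokens)
        top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
        results[qid] = {docs_id[i]: float(scores[i]) for i in top}
    return results

def get_qrels(dataset):
    """Collect relevance labels as {query_id: {doc_id: relevance}}."""
    qrel_dict = {}
    for qrel in dataset.qrels_iter():
        qrel_dict.setdefault(qrel.query_id, {})[qrel.doc_id] = int(qrel.relevance)
    return qrel_dict

docs = [
    "Statin use and breast cancer survival",
    "Who should be careful about curcumin",
    "What is driving the obesity problem",
]
model = TinyBM25(tokenize_corpus(docs))
ranked = rank_documents_bm25(
    ["Do statins improve breast cancer survival?"], ["q1"], ["d1", "d2", "d3"], 2, model
)
print(ranked)
```

Because `rank_documents_bm25` only needs an object with a `get_scores` method, the same function should work unchanged when `TinyBM25` is swapped for `BM25Okapi`.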


In [None]:
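cutoffs = [5, 10, 20]  # NDCG cutoffs, matching the ndcg@5/10/20 results reported later
qrel_dict = get_qrels(dataset)  # get_qrels must be defined beforehand
ndcg = evaluate_bm25(doc_data, query_data, qrel_dict, cutoffs)
print(f"NDCG Score: {ndcg}")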


NDCG Score: 0.5223977672824955
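
For readers unfamiliar with the metric, here is a small self-contained sketch of NDCG@k for a single query. It uses linear gains (relevance divided by a logarithmic rank discount), which is the convention `sklearn.metrics.ndcg_score` follows; the function names are chosen for this example.

```python
import math

def dcg_at_k(relevances, k):
    """Discounted cumulative gain: sum of rel_i / log2(i + 1) over the top k ranks."""
    return sum(rel / math.log2(rank + 1) for rank, rel in enumerate(relevances[:k], start=1))

def ndcg_at_k(true_relevance, predicted_scores, k):
    """DCG of the predicted ordering divided by the DCG of the ideal ordering."""
    # Order the true relevances by descending predicted score.
    order = sorted(range(len(predicted_scores)), key=lambda i: predicted_scores[i], reverse=True)
    ranked = [true_relevance[i] for i in order]
    ideal_dcg = dcg_at_k(sorted(true_relevance, reverse=True), k)
    return dcg_at_k(ranked, k) / ideal_dcg if ideal_dcg > 0 else 0.0

perfect = ndcg_at_k([3, 2, 1], [0.9, 0.8, 0.7], k=3)   # best document ranked first
reversed_ = ndcg_at_k([3, 2, 1], [0.1, 0.2, 0.3], k=3)  # best document ranked last
print(perfect, round(reversed_, 3))
```

The ideal ordering is built from whatever relevance list is passed in, so scoring only the retrieved documents can give a higher value than scoring against every labeled document for the query.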


In [None]:
!pip uninstall -y torch torchvision torchaudio
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

In [None]:
%%capture
from autogluon.multimodal import MultiModalPredictor

predictor = MultiModalPredictor(
        query=query_id_col,
        response=doc_id_col,
        label=label_col,
        problem_type="text_similarity",
        hyperparameters={"model.hf_text.checkpoint_name": "sentence-transformers/all-MiniLM-L6-v2"}
    )

In [None]:
predictor.evaluate(
        labeled_data,
        query_data=query_data[[query_id_col]],
        response_data=doc_data[[doc_id_col]],
        id_mappings=id_mappings,
        cutoffs=cutoffs,
        metrics=["ndcg"],
    )


{'ndcg@5': 0.33672, 'ndcg@10': 0.30893, 'ndcg@20': 0.28197}

In [None]:
from autogluon.multimodal.utils import semantic_search
hits = semantic_search(
        matcher=predictor,
        query_data=query_data[text_col].tolist(),
        response_data=doc_data[text_col].tolist(),
        query_chunk_size=len(query_data),
        top_k=max(cutoffs),
    )


In [None]:
query_embeds = predictor.extract_embedding(query_data[[query_id_col]], id_mappings=id_mappings, as_tensor=True)
doc_embeds = predictor.extract_embedding(doc_data[[doc_id_col]], id_mappings=id_mappings, as_tensor=True)

In [None]:
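import torch
from rank_bm25 import BM25Okapi
# compute_ranking_score is used by evaluate_hybridBM25 below
from autogluon.multimodal.utils import compute_ranking_score, compute_semantic_similarity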

def hybridBM25(query_data, query_embeds, doc_data, doc_embeds, recall_num, top_k, beta):
    # Recall documents with BM25 scores
    tokenized_corpus = tokenize_corpus(doc_data[text_col].tolist())
    bm25_model = BM25Okapi(tokenized_corpus, k1=1.2, b=0.75)
    bm25_scores = rank_documents_bm25(query_data[text_col].tolist(), query_data[query_id_col].tolist(), doc_data[doc_id_col].tolist(), recall_num, bm25_model)

    all_bm25_scores = [score for scores in bm25_scores.values() for score in scores.values()]
    max_bm25_score = max(all_bm25_scores)
    min_bm25_score = min(all_bm25_scores)

    q_embeddings = {qid: embed for qid, embed in zip(query_data[query_id_col].tolist(), query_embeds)}
    d_embeddings = {did: embed for did, embed in zip(doc_data[doc_id_col].tolist(), doc_embeds)}

    query_ids = query_data[query_id_col].tolist()
    results = {qid: {} for qid in query_ids}
    for qid in query_ids:
        rec_docs = bm25_scores[qid]
        rec_doc_emb = [d_embeddings[doc_id] for doc_id in rec_docs.keys()]
        rec_doc_id = [doc_id for doc_id in rec_docs.keys()]
        rec_doc_emb = torch.stack(rec_doc_emb)
        scores = compute_semantic_similarity(q_embeddings[qid], rec_doc_emb)
        scores[torch.isnan(scores)] = -1
        top_k_values, top_k_idxs = torch.topk(
            scores,
            min(top_k + 1, len(scores[0])),
            dim=1,
            largest=True,
            sorted=False,
        )

        # Guard against a zero range when all recalled BM25 scores are equal
        bm25_range = (max_bm25_score - min_bm25_score) or 1.0
        for doc_idx, score in zip(top_k_idxs[0], top_k_values[0]):
            doc_id = rec_doc_id[int(doc_idx)]
            # Hybrid score: weighted sum of cosine similarity and min-max normalized BM25
            results[qid][doc_id] = (
                (1 - beta) * score.item()
                + beta * (bm25_scores[qid][doc_id] - min_bm25_score) / bm25_range
            )

    return results


def evaluate_hybridBM25(query_data, query_embeds, doc_data, doc_embeds, recall_num, beta, cutoffs):
    results = hybridBM25(query_data, query_embeds, doc_data, doc_embeds, recall_num, max(cutoffs), beta)
    ndcg = compute_ranking_score(results=results, qrel_dict=qrel_dict, metrics=["ndcg"], cutoffs=cutoffs)
    return ndcg
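
The scoring rule inside `hybridBM25` can be isolated in a few lines. The sketch below applies the same formula, `(1 - beta) * cosine + beta * normalized_bm25`, to made-up scores for three documents; the helper names and all numbers are illustrative assumptions. Note that `hybridBM25` takes the minimum and maximum BM25 score over all queries, whereas this toy example has a single query.

```python
def min_max_normalize(score, min_score, max_score):
    """Scale a score into [0, 1]; return 0.0 when all scores are equal."""
    if max_score == min_score:
        return 0.0
    return (score - min_score) / (max_score - min_score)

def hybrid_score(cosine_sim, bm25_score, min_bm25, max_bm25, beta):
    """Blend embedding similarity with normalized BM25; beta weights BM25."""
    return (1 - beta) * cosine_sim + beta * min_max_normalize(bm25_score, min_bm25, max_bm25)

# Hypothetical scores for documents recalled by BM25 for one query.
bm25 = {"d1": 10.0, "d2": 6.0, "d3": 2.0}
cosine = {"d1": 0.2, "d2": 0.8, "d3": 0.9}
beta = 0.3

lo, hi = min(bm25.values()), max(bm25.values())
hybrid = {d: hybrid_score(cosine[d], bm25[d], lo, hi, beta) for d in bm25}
reranked = sorted(hybrid, key=hybrid.get, reverse=True)
print(reranked)
```

With `beta = 0.3` the embedding similarity dominates, so the top BM25 document `d1` drops to last place; setting `beta = 1.0` would recover the pure BM25 order.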