# NER Powered Semantic Search

This notebook shows how to use Named Entity Recognition (NER) for vector search with LanceDB. We will:

1. Extract named entities from text.
2. Store them in a LanceDB table as metadata, alongside the corresponding text vectors.
3. Extract named entities from incoming queries and use them to keep only the search results whose records contain the same entities.

This is particularly helpful when you want to restrict search results to records that mention the same named entities as the query.
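The core idea fits in a few lines of plain Python. This is a minimal sketch, assuming made-up three-dimensional vectors and hand-written entity lists in place of real embeddings and NER output; `entity_filtered_search` is a hypothetical helper, not part of LanceDB. It keeps only records that share an entity with the query and then ranks them by cosine similarity.

```python
import math

# Toy index: made-up vectors and entity lists standing in for real
# embeddings and NER output.
records = [
    {"title": "SpaceX and the road to Mars", "vector": [0.9, 0.1, 0.0], "named_entities": ["SpaceX", "Mars"]},
    {"title": "How rocket engines work", "vector": [0.8, 0.2, 0.1], "named_entities": []},
    {"title": "Is there life on Mars?", "vector": [0.5, 0.5, 0.2], "named_entities": ["Mars"]},
]


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def entity_filtered_search(query_vector, query_entities, records, k=10):
    """Keep records sharing at least one entity with the query, then rank by similarity."""
    wanted = set(query_entities)
    candidates = [r for r in records if wanted & set(r["named_entities"])]
    candidates.sort(key=lambda r: cosine(query_vector, r["vector"]), reverse=True)
    return [r["title"] for r in candidates[:k]]


print(entity_filtered_search([1.0, 0.0, 0.0], ["Mars"], records))
# ['SpaceX and the road to Mars', 'Is there life on Mars?']
```

The rocket-engine record is close to the query vector, but it is dropped because it carries no matching entity.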

Let's get started.

# Installing Dependencies

In [1]:
!pip install sentence_transformers datasets lancedb -qU


# Load and Prepare Datasets

We use a dataset of ~190K articles scraped from Medium. Indexing all of them would take a while, so we randomly sample 20K articles. The dataset can be loaded from the Hugging Face Hub as follows:

In [2]:
from datasets import load_dataset

# load the dataset and convert to pandas dataframe
df = load_dataset(
    "fabiochiu/medium-articles", data_files="medium_articles.csv", split="train"
).to_pandas()



## Preprocess the dataset

In [3]:
# drop empty rows and select 20k articles
df = df.dropna().sample(20000, random_state=32)
df.head()

|        | title | text | url | authors | timestamp | tags |
|--------|-------|------|-----|---------|-----------|------|
| 4172 | How the Data Stole Christmas | by Anonymous\n\nThe door sprung open and our t... | https://medium.com/data-ops/how-the-data-stole... | [] | 2019-12-24 13:22:33.143000+00:00 | ['Data Science', 'Big Data', 'Dataops', 'Analy... |
| 174868 | Automating Light Switch using the ESP32 Board ... | A story about how I escaped the boring task th... | https://python.plainenglish.io/automating-ligh... | ['Tomas Rasymas'] | 2021-09-14 07:20:52.342000+00:00 | ['Programming', 'Python', 'Software Developmen... |
| 100171 | Keep Going Quotes Sayings for When Hope is Lost | It’s a very thrilling thing to achieve a goal.... | https://medium.com/@yourselfquotes/keep-going-... | ['Yourself Quotes'] | 2021-01-05 12:13:04.018000+00:00 | ['Quotes'] |
| 141757 | When Will the Smoke Clear From Bay Area Skies? | Bay Area cities are contending with some of th... | https://thebolditalic.com/when-will-the-smoke-... | ['Matt Charnock'] | 2020-09-15 22:38:33.924000+00:00 | ['Bay Area', 'San Francisco', 'California', 'W... |
| 183489 | The ABC’s of Sustainability… easy as 1, 2, 3 | By Julia DiPrete\n\n(according to the Jackson ... | https://medium.com/sipwines/the-abcs-of-sustai... | ['Sip Wines'] | 2021-03-02 23:39:49.948000+00:00 | ['Wine Tasting', 'Sustainability', 'Wine'] |


In [4]:
# select first 1000 characters
df["text"] = df["text"].str[:1000]
# join article title and the text
df["title_text"] = df["title"] + ". " + df["text"]

## Initialize NER model

To extract named entities, we use a BERT-base model fine-tuned for NER (`dslim/bert-base-NER`). The model can be loaded from the Hugging Face Model Hub as follows:

In [5]:
import torch

# set device to GPU if available
device = torch.cuda.current_device() if torch.cuda.is_available() else None

In [6]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

model_id = "dslim/bert-base-NER"

# load the tokenizer from huggingface
tokenizer = AutoTokenizer.from_pretrained(model_id)
# load the NER model from huggingface
model = AutoModelForTokenClassification.from_pretrained(model_id)
# load the tokenizer and model into a NER pipeline
nlp = pipeline(
    "ner", model=model, tokenizer=tokenizer, aggregation_strategy="max", device=device
)



In [7]:
text = "What are the best Places to visit in London"
# use the NER pipeline to extract named entities from the text
nlp(text)

[{'entity_group': 'LOC',
  'score': 0.99969244,
  'word': 'London',
  'start': 37,
  'end': 43}]

The NER pipeline works as expected: it tags "London" as a location (`LOC`) with a confidence score of about 0.9997.
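The pipeline returns one dictionary per entity rather than one per token because `aggregation_strategy` merges the model's token-level tags. The model labels tokens with a BIO scheme: `B-` begins an entity, `I-` continues it, and `O` is outside any entity. The sketch below illustrates only the span-merging idea on whole words with hand-written labels; `group_bio_tags` is a hypothetical helper, not the `transformers` implementation, and it ignores subword tokens and the per-word score handling that the `"max"` strategy controls.

```python
def group_bio_tags(tokens, labels):
    """Merge word-level BIO labels into (entity_type, text) spans.

    Simplified illustration: works on whole words and ignores scores.
    """
    entities = []
    span_type, span_words = None, []
    for token, label in zip(tokens, labels):
        # "B-LOC" -> prefix "B", type "LOC"; "O" -> prefix "O", type ""
        prefix, _, ent_type = label.partition("-")
        if prefix == "I" and ent_type == span_type:
            span_words.append(token)  # continue the open span
            continue
        if span_words:  # any other label closes the open span
            entities.append((span_type, " ".join(span_words)))
        if prefix == "O":
            span_type, span_words = None, []
        else:  # B- tag, or an I- tag of a different type, opens a new span
            span_type, span_words = ent_type, [token]
    if span_words:
        entities.append((span_type, " ".join(span_words)))
    return entities


# hand-written labels for a made-up sentence
tokens = ["Jane", "Doe", "joined", "Acme", "in", "New", "York"]
labels = ["B-PER", "I-PER", "O", "B-ORG", "O", "B-LOC", "I-LOC"]
print(group_bio_tags(tokens, labels))
# [('PER', 'Jane Doe'), ('ORG', 'Acme'), ('LOC', 'New York')]
```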

## Initialize Retriever

A retriever model embeds both passages (article title + first 1000 characters) and queries, placing texts with similar meanings close together in the vector space. We use a Sentence Transformers model as the retriever; it can be loaded as follows:

In [8]:
from sentence_transformers import SentenceTransformer

# load the model from huggingface
retriever = SentenceTransformer(
    "flax-sentence-embeddings/all_datasets_v3_mpnet-base", device=device
)
retriever


SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: MPNetModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)
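The printed model ends with a `Normalize()` module, so every embedding it produces has unit length. For unit vectors the dot product equals cosine similarity, and squared Euclidean distance equals `2 - 2 * cosine`, so ranking by L2 distance (LanceDB's default vector search metric at the time of writing) gives the same order as ranking by cosine similarity. The stdlib sketch below checks this on made-up 2-D vectors; the helper names are illustrative only.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def normalize(v):
    """Scale a vector to unit length, as a Normalize() layer does."""
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]


def cosine(a, b):
    return dot(a, b) / math.sqrt(dot(a, a) * dot(b, b))


def squared_l2(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


query, passage = [3.0, 4.0], [4.0, 3.0]  # made-up, unnormalized vectors
q, p = normalize(query), normalize(passage)

print(round(cosine(query, passage), 6))  # 0.96
print(round(dot(q, p), 6))  # 0.96: dot product of unit vectors = cosine
print(round(squared_l2(q, p), 6))  # 0.08 = 2 - 2 * 0.96
```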

## Initialize LanceDB

In [10]:
import lancedb

db = lancedb.connect("./.lancedb")

## Generate Embeddings and Insert

We generate embeddings for the `title_text` column created earlier and store each record's named entities alongside its embedding as metadata. Later, we use these named entities to filter the results of each query.

Let's first write a helper function to extract named entities from a batch of text.

In [9]:
def extract_named_entities(text_batch):
    # extract named entities using the NER pipeline
    extracted_batch = nlp(text_batch)
    entities = []
    # loop through the results and only select the entity names
    for text in extracted_batch:
        ne = [entity["word"] for entity in text]
        entities.append(ne)
    return entities
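The next cell indexes the 20K sampled articles in batches of 64. Spelling out the batch boundaries shows why the loop clamps `i_end` with `min()`: 20,000 / 64 = 312.5, so there are 313 batches and the last one holds only 32 articles. `batch_bounds` is a hypothetical helper used only for this illustration.

```python
import math


def batch_bounds(n_items, batch_size):
    """Yield (start, end) pairs that split range(n_items) into fixed-size batches."""
    for start in range(0, n_items, batch_size):
        # the final batch may be shorter, so clamp the end index
        yield start, min(start + batch_size, n_items)


bounds = list(batch_bounds(20000, 64))
print(len(bounds), math.ceil(20000 / 64))  # 313 313
print(bounds[-1])  # (19968, 20000)
```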

In [11]:
from tqdm.auto import tqdm
import warnings

import numpy as np

warnings.filterwarnings("ignore", category=UserWarning)

# process the articles in batches of 64
batch_size = 64
data = []

for i in tqdm(range(0, len(df), batch_size)):
    # find end of batch
    i_end = min(i + batch_size, len(df))
    # extract batch
    batch = df.iloc[i:i_end].copy()
    # generate embeddings for batch
    embeddings = retriever.encode(batch["title_text"].tolist()).tolist()
    # extract named entities from the batch
    entities = extract_named_entities(batch["title_text"].tolist())
    # remove duplicate entities from each record
    batch["named_entities"] = [list(set(entity)) for entity in entities]
    batch = batch.drop("title_text", axis=1)
    # convert the remaining columns (including named_entities) to metadata dicts
    metadata = batch.to_dict(orient="records")
    # build one record per article: vector, metadata and named entities
    for vector, meta, named_entities in zip(
        embeddings, metadata, batch["named_entities"]
    ):
        data.append(
            {
                "vector": np.array(vector),
                "metadata": meta,
                "named_entities": named_entities,
            }
        )


In [12]:
# create table using above data
tbl = db.create_table("tw", data)
# check table
tbl.head()

pyarrow.Table
vector: fixed_size_list<item: float>[768]
  child 0, item: float
metadata: struct<authors: string, named_entities: list<item: string>, tags: string, text: string, timestamp: string, title: string, url: string>
  child 0, authors: string
  child 1, named_entities: list<item: string>
      child 0, item: string
  child 2, tags: string
  child 3, text: string
  child 4, timestamp: string
  child 5, title: string
  child 6, url: string
named_entities: list<item: string>
  child 0, item: string
----
vector: [[[-0.009049614,0.10612086,-0.027753588,0.07209486,-0.032509252,...,0.11016317,-0.013526588,-0.0046699173,0.035262693,-0.051537305],[-0.04104508,-0.049538508,-0.026324937,0.019106576,-0.017135208,...,-0.050371084,-0.058374014,0.014137886,-0.046907514,-0.012160475],[0.0068191127,0.05442695,0.0059523294,-0.0272331,0.05366467,...,0.05989369,-0.02457071,-0.01919812,0.059475537,-0.040533062],[-0.074406564,0.06398625,-0.0032167286,0.0006136286,-0.04038913,...,-0.0035826706,0.0176
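The indexing loop removes duplicate entities with `list(set(entity))`. That is correct, but a set does not remember the order in which the NER model found the entities, and because Python randomizes string hashing per process, the stored order can change from run to run. If a stable order matters, `dict.fromkeys` removes duplicates while keeping first-seen order. A minimal sketch with a hypothetical `dedupe_entities` helper:

```python
def dedupe_entities(entities):
    """Drop duplicate entity strings, keeping the order they first appeared in."""
    # dict keys are unique and (since Python 3.7) keep insertion order
    return list(dict.fromkeys(entities))


raw = ["SpaceX", "Mars", "SpaceX", "NASA", "Mars"]
print(dedupe_entities(raw))  # ['SpaceX', 'Mars', 'NASA']
```

Swapping it into the loop would mean `batch["named_entities"] = [dedupe_entities(e) for e in entities]`.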

# Querying

In [13]:
from pprint import pprint


def search_lancedb(query):
    # extract named entities from the query
    ne = extract_named_entities([query])[0]
    # create an embedding for the query
    xq = retriever.encode(query).tolist()
    # vector search over the table (returns LanceDB's default number of nearest neighbours)
    xc = tbl.search(xq).to_list()
    # keep only results that share a named entity with the query;
    # a title is added once for every matching entity
    r = [
        x["metadata"]["title"]
        for x in xc
        for i in x["metadata"]["named_entities"]
        if i in ne
    ]
    pprint({"Extracted Named Entities": ne, "Result": r})

Now let's try some queries:

In [14]:
query = "How Data is changing world?"
search_lancedb(query)

{'Extracted Named Entities': ['Data'],
 'Result': ['Data Science is all about making the right choices']}


In [15]:
query = "Why does SpaceX want to build a city on Mars?"
search_lancedb(query)

{'Extracted Named Entities': ['SpaceX', 'Mars'],
 'Result': ['Mars Habitat: NASA 3D Printed Habitat Challenge',
            'Reusable rockets and the robots at sea: The SpaceX story',
            'Reusable rockets and the robots at sea: The SpaceX story',
            'Colonising Planets Beyond Mars',
            'Colonising Planets Beyond Mars',
            'Musk Explained: The Musk approach to marketing',
            'How We’ll Access the Water on Mars',
            'Chasing Immortality',
            'Mission Possible: How Space Exploration Can Deliver Sustainable '
            'Development']}


These results are relevant to both queries: every returned article was tagged with at least one of the query's named entities, and the titles keep the ranking order of LanceDB's vector search.
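Because `search_lancedb` appends a title once for each matching entity, an article that mentions several of the query's entities is listed more than once. Below is a minimal sketch of a post-filter that keeps each matching result once, in the order LanceDB ranked them; `filter_titles_by_entities` is a hypothetical helper, and the toy `results` list only mimics the shape of the rows returned by `tbl.search(...).to_list()` in this notebook.

```python
def filter_titles_by_entities(results, query_entities):
    """Return titles of results sharing at least one entity with the query, once each."""
    wanted = set(query_entities)
    return [
        row["metadata"]["title"]
        for row in results
        if wanted & set(row["metadata"]["named_entities"])
    ]


# toy rows shaped like the search results above (titles and entities are made up)
results = [
    {"metadata": {"title": "A", "named_entities": ["SpaceX", "Mars"]}},
    {"metadata": {"title": "B", "named_entities": ["Tesla"]}},
    {"metadata": {"title": "C", "named_entities": ["Mars"]}},
]
print(filter_titles_by_entities(results, ["SpaceX", "Mars"]))  # ['A', 'C']
```

Filtering after the vector search can only narrow the default top-k result set, which is why the first query returned a single title. LanceDB's query builder also offers a `.where(...)` SQL filter that can be applied as a prefilter before the nearest-neighbour search; how to test membership in a list column such as `named_entities` depends on the SQL functions your LanceDB version supports, so check its filtering documentation before relying on it.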