In [2]:
import pandas as pd
import numpy as np
from annoy import AnnoyIndex 

from transformers import AutoTokenizer, AutoModel
import torch

# Load once (outside the function if embedding multiple times)
tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")
model.eval()  # optional, ensures no dropout during inference
# device = "cuda" if torch.cuda.is_available() else "mps"
# model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(31090, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [3]:
def get_scibert_embedding(text: str, pooling: str = "mean") -> torch.Tensor:
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
        last_hidden_state = outputs.last_hidden_state  # [1, seq_len, 768]
    
    if pooling == "mean":
        embedding = last_hidden_state.mean(dim=1)  # Average pooling
    elif pooling == "cls":
        embedding = last_hidden_state[:, 0, :]  # [CLS] token
    else:
        raise ValueError("Pooling must be 'mean' or 'cls'")
    return embedding[0]  # shape: [1, 768]


In [4]:
ls = ['deep learning', 'machine learning', 'artificial intelligence', 'natural language processing']

In [5]:
embeddings_combined_dict = {}

for i in ls:
    embeddings_combined_dict[i] = get_scibert_embedding(i, pooling="mean")

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [8]:
len(embeddings_combined_dict['deep learning'])

768

In [10]:
alternate_keys = {}
start = 1
for i in embeddings_combined_dict:
    alternate_keys[i] = start
    start += 1

reverse_alternate_keys = {}

for i, j in alternate_keys.items():
    reverse_alternate_keys[j] = i

In [12]:
alternate_keys, reverse_alternate_keys

({'deep learning': 1,
  'machine learning': 2,
  'artificial intelligence': 3,
  'natural language processing': 4},
 {1: 'deep learning',
  2: 'machine learning',
  3: 'artificial intelligence',
  4: 'natural language processing'})

In [13]:
f = 768 # Number of Dimensions
t = AnnoyIndex(f)
for i, j in embeddings_combined_dict.items():
    t.add_item(alternate_keys[i], j)
    
t.build(f)
t.save('entity-search-tree.ann')

  t = AnnoyIndex(f)


True

In [14]:
search_space = AnnoyIndex(768)
search_space.load('./entity-search-tree.ann')

  search_space = AnnoyIndex(768)


True

In [17]:
def text_space_search(query: str, num : int = 10):
    query_vector = get_scibert_embedding(query) 
    ans = search_space.get_nns_by_vector(query_vector, num)
    return ans

In [19]:
for i in text_space_search('deep learning', 5):
    print(reverse_alternate_keys[i])

deep learning
machine learning
artificial intelligence
natural language processing
