# Embedding models

In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2')
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
)

In [3]:
tokenized_data = model.tokenize(["walker walked a long walk"])
tokenized_data


{'input_ids': tensor([[ 101, 5232, 2939, 1037, 2146, 3328,  102]]),
 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}

In [4]:
first_module = model._first_module()
first_module.auto_model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 384, padding_idx=0)
    (position_embeddings): Embedding(512, 384)
    (token_type_embeddings): Embedding(2, 384)
    (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-5): 6 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=384, out_features=384, bias=True)
            (key): Linear(in_features=384, out_features=384, bias=True)
            (value): Linear(in_features=384, out_features=384, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=384, out_features=384, bias=True)
            (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)


In [5]:
embeddings = first_module.auto_model.embeddings
embeddings

BertEmbeddings(
  (word_embeddings): Embedding(30522, 384, padding_idx=0)
  (position_embeddings): Embedding(512, 384)
  (token_type_embeddings): Embedding(2, 384)
  (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [9]:
import torch
import plotly.express as px

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
device = 'mps'


first_sentence = "vector search optimization"
second_sentence = "we learn about vector search optimization"

with torch.no_grad():
    first_tokens = model.tokenize([first_sentence])
    second_tokens = model.tokenize([second_sentence])

     # Convert the input_ids tensors to the device
    first_embedding = embeddings.word_embeddings(first_tokens['input_ids'].to(device))
    second_embedding = embeddings.word_embeddings(second_tokens['input_ids'].to(device))


first_embedding.shape, second_embedding.shape



(torch.Size([1, 5, 384]), torch.Size([1, 8, 384]))

In [12]:
from sentence_transformers import util



distances= util.cos_sim(first_embedding.squeeze(), second_embedding.squeeze()).cpu().numpy()
distances = distances.T
px.imshow(
    distances,
    x=model.tokenizer.convert_ids_to_tokens(
        first_tokens['input_ids'][0]
        ),
    y=model.tokenizer.convert_ids_to_tokens(
        second_tokens['input_ids'][0]
        ),
    text_auto=True,
)


In [13]:
token_embeddings = first_module.auto_model \
                    .embeddings \
                    .word_embeddings \
                    .weight \
                    .detach() \
                    .cpu() \
                    .numpy()

token_embeddings.shape

(30522, 384)

In [14]:
import random

vocabulary = first_module.tokenizer.get_vocab()
sorted_vocabulary = sorted(
    vocabulary.items(),
    key=lambda x: x[1]
)

sorted_tokens = [token for token, _ in sorted_vocabulary]
random.choices(sorted_tokens, k=100)


['whorls',
 'teachings',
 'blond',
 '[unused421]',
 'tempered',
 '##oed',
 'series',
 'mickey',
 'periodical',
 'outstanding',
 'inflicted',
 'offs',
 'flying',
 'embracing',
 'growling',
 'majors',
 'adolescents',
 'geological',
 'huddled',
 'safer',
 'attributes',
 '‡',
 'societal',
 'politically',
 'vigor',
 'offenses',
 'depart',
 'documentation',
 'cello',
 '##dora',
 'overdose',
 'pagoda',
 'draper',
 'translations',
 '[unused236]',
 'aspect',
 'affairs',
 'fathers',
 '276',
 '##riding',
 'automated',
 'leukemia',
 'troll',
 'creeks',
 'imminent',
 'hated',
 'founders',
 'residency',
 '##ː',
 '##economic',
 '1920',
 'hostile',
 'francesco',
 'robbins',
 'downloaded',
 'francais',
 'accepting',
 '##app',
 'twinkle',
 'satisfy',
 '##⁴',
 'myanmar',
 'irony',
 'adjustments',
 'flanders',
 'solved',
 'eireann',
 'doncaster',
 '##olio',
 'promotes',
 'distal',
 '##ी',
 'nixon',
 '##igh',
 'straw',
 'bradley',
 'eel',
 'thirty',
 '和',
 'cia',
 'lovely',
 'enroll',
 '##erate',
 'landmar

In [15]:
from sklearn.manifold import TSNE

tsne = TSNE(n_components=2, random_state=42, metric='cosine')
tsne_embeddings_2d = tsne.fit_transform(token_embeddings)
tsne_embeddings_2d.shape


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


(30522, 2)

In [16]:
token_colors = []

for token in sorted_tokens:
    if token[0] == "[" and token[-1] == "]":
        token_colors.append("red")
    elif token.startswith("##"):
        token_colors.append("blue")
    else:
        token_colors.append("green")


In [18]:
import plotly.graph_objs as go

scatter = go.Scatter(
    x=tsne_embeddings_2d[:, 0],
    y=tsne_embeddings_2d[:, 1],
    text=sorted_tokens,
    mode="markers",
    marker=dict(
        color=token_colors,
        size=3,
    ),
)

fig = go.Figure(data=[scatter],
                layout=dict(
                    width=600,
                    height=900,
                    margin=dict(l=0, r=0,),
                )
)

fig.show()

