In [None]:
import torch
import faiss
import numpy as np
from transformers import AutoTokenizer
from model import SimCSEModel
from dataset import SimCSEDataset
from tqdm import tqdm

In [14]:
# Pretrained checkpoint name; used for the tokenizer below and passed to SimCSEModel.
MODEL_NAME = "bert-base-uncased"
# Maximum token length given to SimCSEDataset and used as the truncation limit for queries.
MAX_LEN = 32
# Number of sentences per batch when the corpus is encoded.
BATCH_SIZE = 64
# Weights file loaded into SimCSEModel; presumably produced by a separate
# training run -- TODO confirm where it comes from.
CHECKPOINT_PATH = './checkpoint/best_model.pth'

# Tokenizer matching MODEL_NAME; shared by corpus encoding and query-time search.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [3]:
from datasets import load_dataset

# Each training example of DailyDialog is one dialog: a list of utterance strings.
raw_dataset = load_dataset("daily_dialog")
dialogs = raw_dataset['train']['dialog']

# Flatten every dialog into a single stream of whitespace-trimmed utterances,
# then keep only the non-empty ones (order of appearance is preserved).
stripped_utterances = (
    utterance.strip()
    for dialog in dialogs
    for utterance in dialog
)
sentences = [text for text in stripped_utterances if text != ""]

README.md:   0%|          | 0.00/7.27k [00:00<?, ?B/s]

daily_dialog.py:   0%|          | 0.00/4.85k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.48M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/11118 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [4]:
# Wrap the corpus so every sentence is tokenized to at most MAX_LEN tokens.
dataset = SimCSEDataset(
    sentences,
    tokenizer,
    max_len=MAX_LEN,
)

# shuffle=False keeps batches in corpus order, so the position of an embedding
# can later be used directly as an index into `dataset`.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [15]:
# Load model
model = SimCSEModel(MODEL_NAME).cuda()

# torch.load unpickles the file. weights_only=True restricts unpickling to
# tensors and plain containers, so a tampered checkpoint cannot execute code.
# map_location="cpu" makes loading independent of the device the checkpoint was
# saved from; load_state_dict then copies the values into the GPU parameters.
model.load_state_dict(
    torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
)

# Inference only: switch off dropout (also displays the module tree as cell output).
model.eval()

SimCSEModel(
  (backbone): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [16]:
def get_embeddings(dataloader):
    """Encode every batch of ``dataloader`` with the global ``model``.

    Each batch is unpacked as ``(input_ids, attention_mask)``, moved to the
    GPU, and passed through the model in eval mode with gradients disabled.

    Returns:
        A NumPy array made by vertically stacking the per-batch outputs, in
        the order the dataloader yielded them.
    """
    model.eval()
    batch_outputs = []
    with torch.no_grad():
        for input_ids, attention_mask in tqdm(dataloader, desc="Embedding"):
            encoded = model(
                input_ids.cuda(),
                attention_mask=attention_mask.cuda(),
            )
            # Move results off the GPU right away so only one batch lives there.
            batch_outputs.append(encoded.cpu().numpy())
    return np.vstack(batch_outputs)

In [17]:
# Encode the whole corpus. The dataloader is not shuffled, so row i of
# `embeddings` is the vector for dataset[i].
embeddings = get_embeddings(dataloader)

Embedding: 100%|██████████| 1363/1363 [00:45<00:00, 29.93it/s]


In [18]:
# Build an exact (flat, brute-force) FAISS index over the corpus embeddings.
# NOTE(review): SimCSE-style encoders are usually trained with cosine similarity;
# confirm that L2 on unnormalized vectors is the intended metric here.
dim = embeddings.shape[1]  # embedding width (number of columns)
index = faiss.IndexFlatL2(dim)  # Euclidean metric; FAISS reports *squared* L2 distances
index.add(embeddings)  # ids are assigned sequentially, so id i refers to row i

In [19]:
def search(query, k=5):
    """Embed ``query`` with the global model and look up its nearest corpus entries.

    Args:
        query: Text to search for. It is passed straight to the tokenizer and
            truncated to ``MAX_LEN`` tokens.
        k: Number of neighbours to retrieve.

    Returns:
        ``(distances, indices)`` exactly as returned by ``index.search``: one
        row per query, ``k`` columns each. ``indices`` are positions in the
        order the embeddings were added to the index.
    """
    query_tokens = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LEN)
    query_input_ids = query_tokens['input_ids'].cuda()
    query_attention_mask = query_tokens['attention_mask'].cuda()

    # Inference only: eval mode disables dropout so a query always maps to the
    # same vector, and no_grad avoids building an autograd graph per call.
    model.eval()
    with torch.no_grad():
        query_output = model(query_input_ids, attention_mask=query_attention_mask)
    query_embedding = query_output.detach().cpu().numpy()

    # Search in the Faiss index
    distances, indices = index.search(query_embedding, k)
    return distances, indices

In [20]:
query = "What do you want to eat?"
distances, indices = search(query)

# Only one query was submitted, so row 0 holds all of its neighbours.
# Each hit is shown by decoding the stored token ids back to text.
for rank, (idx, distance) in enumerate(zip(indices[0], distances[0]), start=1):
    token_ids, _ = dataset[idx]
    text = tokenizer.decode(token_ids, skip_special_tokens=True)

    print(f"Rank {rank}: {text} | Distance: {distance:.4f}")

Rank 1: what do you want to eat? | Distance: 0.0000
Rank 2: what do you eat? | Distance: 29.6953
Rank 3: what are you eating? | Distance: 46.0078
Rank 4: what do you want to eat today? | Distance: 46.6447
Rank 5: what would you like to eat? | Distance: 51.3149
