In [1]:
# Libraries

import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
import torch.nn.functional as F
import faiss  # for efficient similarity search

In [2]:
# Load tokenizer and model

MODEL_NAME = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
model.eval()  # Set model to evaluation mode

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [3]:
# Encode Text function

def encode_text(texts, tokenizer, model, max_length=128, device="cpu"):
    """Encodes a list of texts into contextual embeddings."""
    # Tokenize input texts
    inputs = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    
    # Move data to device
    inputs = {key: val.to(device) for key, val in inputs.items()}
    
    # Forward pass through the model
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Use last hidden state (token-level embeddings)
    token_embeddings = outputs.last_hidden_state  # Shape: (batch_size, seq_len, hidden_dim)
    return token_embeddings


In [4]:
# Index Documents

def index_documents(documents, tokenizer, model, device="cpu"):
    """Precomputes document embeddings and indexes them."""
    # Encode documents
    document_embeddings = []
    for doc in documents:
        token_embeds = encode_text([doc], tokenizer, model, device=device)
        # Aggregate embeddings (e.g., mean-pooling or use token embeddings directly)
        doc_emb = token_embeds.mean(dim=1)  # Use mean of all tokens
        document_embeddings.append(doc_emb.squeeze().cpu().numpy())
    
    # Convert embeddings to FAISS index
    embedding_matrix = np.stack(document_embeddings)
    index = faiss.IndexFlatL2(embedding_matrix.shape[1])  # L2 distance index
    index.add(embedding_matrix)
    
    return index, document_embeddings


In [5]:
# Max Sim Function

def compute_max_similarity(query_embeds, doc_embeds):
    """Compute max-sim using cosine similarity between query and document embeddings."""
    # Normalize embeddings to unit vectors
    query_embeds = F.normalize(query_embeds, p=2, dim=1)  # Shape: (seq_len_query, hidden_dim)
    doc_embeds = F.normalize(doc_embeds, p=2, dim=1)      # Shape: (seq_len_doc, hidden_dim)

    # Compute similarity matrix
    similarity_matrix = torch.matmul(query_embeds, doc_embeds.T)  # Shape: (seq_len_query, seq_len_doc)

    # Take max similarity for each query token
    max_sim_per_query_token, _ = similarity_matrix.max(dim=1)  # Shape: (seq_len_query)

    # Aggregate scores (e.g., summing or mean)
    total_similarity = max_sim_per_query_token.sum()  # Scalar
    return total_similarity.item()

In [6]:
# Retrieval with Late Interaction

def retrieve_with_late_interaction(query, documents, tokenizer, model, top_k=5, device="cpu"):
    """Retrieve documents using max-sim late interaction."""
    # Encode query tokens
    query_embeddings = encode_text([query], tokenizer, model, device=device).squeeze(0)  # Shape: (seq_len_query, hidden_dim)

    # Compute max-sim score for each document
    scores = []
    for doc in documents:
        doc_embeddings = encode_text([doc], tokenizer, model, device=device).squeeze(0)  # Shape: (seq_len_doc, hidden_dim)
        score = compute_max_similarity(query_embeddings, doc_embeddings)
        scores.append(score)

    # Rank documents by similarity
    ranked_indices = np.argsort(scores)[::-1][:top_k]
    ranked_results = [(documents[i], scores[i]) for i in ranked_indices]

    return ranked_results


In [7]:
# Sample documents

documents = [
    "The Eiffel Tower is in Paris.",
    "The Mona Lisa is in the Louvre.",
    "The Great Wall of China is a historic site.",
    "Mount Everest is the tallest mountain on Earth."
]

In [8]:
# Precompute embeddings and index them

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
index, doc_embeddings = index_documents(documents, tokenizer, model, device=device)

In [9]:
# Query the documents with late interaction

query = "Where is the Eiffel Tower?"
results = retrieve_with_late_interaction(query, documents, tokenizer, model, top_k=1, device=device)

In [10]:
# Results
print("Query:", query)
print("Results:")
for doc, score in results:
    print(f"- {doc} (Max Sim Score: {score:.4f})")

Query: Where is the Eiffel Tower?
Results:
- The Eiffel Tower is in Paris. (Max Sim Score: 97.0134)
