In [1]:
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModel

# Load model and tokenizer
MODEL_PATH ="/Users/sir/Downloads/HuggingFace/sentence_transformer/intfloat_e5-large-v2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModel.from_pretrained(MODEL_PATH)

# Example batch of paired inputs: (query, positive passage)
texts = [
    ("how much protein should a female eat", "The CDC average protein requirement for women is 46 grams per day."),
    ("define summit", "Definition of summit is the highest point of a mountain.")
]

In [6]:
print(model)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 1024, padding_idx=0)
    (position_embeddings): Embedding(512, 1024)
    (token_type_embeddings): Embedding(2, 1024)
    (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-23): 24 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, 

In [10]:
# --- E5 SPECIFIC HELPER FUNCTIONS ---

def prefix_input(texts):
    """Adds the required prefix to queries and passages."""
    queries = [f"query: {q}" for q, p in texts]
    passages = [f"passage: {p}" for q, p in texts]
    return queries, passages

def average_pool(last_hidden_states, attention_mask):
    """
    Performs masked average pooling over the token embeddings.
    This function handles the variable sequence length by averaging only the 
    tokens that were actually part of the input (where attention_mask is 1).
    """
    # Sum of token embeddings, masked by the attention mask
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    
    # Calculate the sum of token embeddings
    sum_embeddings = last_hidden.sum(dim=1)
    
    # Calculate the count of non-masked tokens for division
    sum_mask = attention_mask.sum(dim=1).unsqueeze(-1)
    
    # Return the average embedding, handling potential division by zero
    return sum_embeddings / torch.clamp(sum_mask, min=1e-9)

def get_embeddings(texts, is_query=False):
    """Generates the final L2-normalized, pooled embeddings."""
    
    # Tokenization
    inputs = tokenizer(
        texts, 
        max_length=512, 
        padding=True, 
        truncation=True, 
        return_tensors='pt'
    )
    
    # Model Forward Pass
    with torch.no_grad():
        outputs = model(**inputs)
        
    # Pooling
    # The E5 paper recommends using the average of the last hidden states
    embeddings = average_pool(outputs.last_hidden_state, inputs['attention_mask'])
    
    # Normalization (Crucial for E5 models to make Cosine Sim == Dot Product)
    embeddings = F.normalize(embeddings, p=2, dim=1)
    
    return embeddings

# --- EXECUTION ---

print("--- Step 1: Prepare and Prefix Inputs ---")
queries, passages = prefix_input(texts)
print(f"Example Query Input: {queries[0]}")
print(f"Example Passage Input: {passages[0]}")
print("-" * 30)

print("--- Step 2: Generate L2-Normalized Embeddings ---")
query_embeds = get_embeddings(queries, is_query=True)
passage_embeds = get_embeddings(passages, is_query=False)

# Check the output shape and normalization:
print(f"Query Embeddings Shape: {query_embeds.shape}")
print(f"Passage Embeddings Shape: {passage_embeds.shape}") 
# The expected dimension is 1024, as previously discussed.

print("-" * 30)

# --- Step 3: Calculate Relevance (Cosine Similarity) ---
# Since embeddings are L2 normalized, Cosine Similarity is equivalent to the Dot Product
# using torch.mm() for matrix multiplication.

# We only care about the similarity between the corresponding (query, passage) pairs
# So we use the diagonal of the dot product matrix.
similarity_scores = torch.sum(query_embeds * passage_embeds, dim=1) 

print("--- Final Relevance Scores (Cosine Similarity) ---")
for i, score in enumerate(similarity_scores):
    print(f"Pair {i+1} ('{texts[i][0]}'): {score.item():.4f}")
    
# High scores (closer to 1.0) indicate high relevance.

--- Step 1: Prepare and Prefix Inputs ---
Example Query Input: query: how much protein should a female eat
Example Passage Input: passage: The CDC average protein requirement for women is 46 grams per day.
------------------------------
--- Step 2: Generate L2-Normalized Embeddings ---
Query Embeddings Shape: torch.Size([2, 1024])
Passage Embeddings Shape: torch.Size([2, 1024])
------------------------------
--- Final Relevance Scores (Cosine Similarity) ---
Pair 1 ('how much protein should a female eat'): 0.9487
Pair 2 ('define summit'): 0.9562


In [2]:
# Tokenize batch
queries, positives = zip(*texts)
query_inputs = tokenizer(list(queries), return_tensors='pt', padding=True, truncation=True)
positive_inputs = tokenizer(list(positives), return_tensors='pt', padding=True, truncation=True)

In [7]:
queries, positives

(('how much protein should a female eat', 'define summit'),
 ('The CDC average protein requirement for women is 46 grams per day.',
  'Definition of summit is the highest point of a mountain.'))

In [8]:
query_inputs

{'input_ids': tensor([[ 101, 2129, 2172, 5250, 2323, 1037, 2931, 4521,  102],
        [ 101, 9375, 6465,  102,    0,    0,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 0, 0]])}

In [9]:
positive_inputs

{'input_ids': tensor([[  101,  1996, 26629,  2779,  5250,  9095,  2005,  2308,  2003,  4805,
         20372,  2566,  2154,  1012,   102],
        [  101,  6210,  1997,  6465,  2003,  1996,  3284,  2391,  1997,  1037,
          3137,  1012,   102,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])}

In [3]:
# Forward pass to get embeddings
query_embeds = model(**query_inputs).last_hidden_state.mean(dim=1)
positive_embeds = model(**positive_inputs).last_hidden_state.mean(dim=1)

# Normalize embeddings
query_embeds = F.normalize(query_embeds, p=2, dim=1)
positive_embeds = F.normalize(positive_embeds, p=2, dim=1)

# Compute cosine similarity matrix
similarity_matrix = torch.matmul(query_embeds, positive_embeds.t())

# Create labels (diagonal matches)
labels = torch.arange(len(query_embeds)).to(similarity_matrix.device)

# Temperature parameter to scale similarities
temperature = 0.05
logits = similarity_matrix / temperature

# Contrastive loss with cross-entropy
loss = F.cross_entropy(logits, labels)

# Backprop and optimization
optimizer = AdamW(model.parameters(), lr=5e-5)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"Contrastive loss: {loss.item()}")

Contrastive loss: 0.027333537116646767


In [4]:

query_embeds

tensor([[ 0.0184, -0.0608,  0.0260,  ..., -0.0548,  0.0031,  0.0019],
        [ 0.0076, -0.0457,  0.0048,  ..., -0.0279,  0.0073,  0.0357]],
       grad_fn=<DivBackward0>)