In [8]:
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from typing import List, Dict, Optional
from pydantic import BaseModel, Field, ConfigDict
import torch.nn.functional as F
import pandas as pd

In [2]:
# Embedding checkpoint. Alternatives that were also tried:
#   "sentence-transformers/all-MiniLM-L6-v2"
#   "BAAI/bge-small-en-v1.5"
model_name = "BAAI/bge-m3"

# Optional explicit device override (e.g. "cpu", "cuda:1"); None means auto-detect.
device = None

# Resolve the device: honour an explicit override first, otherwise prefer
# CUDA, then Apple's MPS backend, and finally fall back to the CPU.
if device is not None:
    device = torch.device(device)
elif torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

print(f"Using device: {device}")

# Load the tokenizer and encoder once; the chunking helpers below read these globals.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)


Using device: mps


In [3]:
def _get_token_embeddings(text: str, max_length: int = 8192):
    """Encode ``text`` in a single forward pass and return per-token embeddings.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        text: Document to encode.
        max_length: Maximum number of tokens (special tokens included) fed to
            the model. Longer inputs are truncated, so text beyond this limit
            receives no embeddings at all; a warning is emitted when the limit
            is reached. Defaults to 8192.

    Returns:
        A ``(token_embeddings, input_ids)`` pair for the single encoded
        sequence: ``token_embeddings`` is ``last_hidden_state[0]`` (one row per
        token) and ``input_ids`` holds the matching token ids.
    """
    import warnings

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    ).to(device)

    input_ids = inputs["input_ids"][0]
    if input_ids.shape[0] >= max_length:
        # The tokenizer truncates silently; surface it, because every
        # downstream chunk would otherwise quietly omit the tail of the text.
        warnings.warn(
            f"Input reached max_length={max_length} tokens; text beyond this "
            "point may have been truncated and will not be embedded.",
            stacklevel=2,
        )

    with torch.no_grad():
        model_output = model(**inputs)
        # last_hidden_state is [batch_size, sequence_length, embedding_dim];
        # only one text is encoded, so keep the first (and only) sequence.
        token_embeddings = model_output.last_hidden_state[0]

    return token_embeddings, input_ids

In [4]:
def chunk_text_with_embeddings(text: str, chunk_size: int = 100, overlap: int = 20) -> List[Dict]:
    """Split ``text`` into token windows that carry "late-chunking" embeddings.

    The whole document is encoded once, so every token embedding is
    contextualised by the full text. Only then is the token sequence cut into
    overlapping windows, and each window's token embeddings are mean-pooled.

    Args:
        text: Document to chunk.
        chunk_size: Number of tokens per window; must be positive.
        overlap: Number of tokens shared by consecutive windows; must satisfy
            ``0 <= overlap < chunk_size`` so that the window always advances.

    Returns:
        One dict per window with keys:
            ``"text"``: decoded window text (special tokens skipped),
            ``"embedding"``: L2-normalised mean-pooled embedding as a NumPy array,
            ``"token_range"``: ``(start, end)`` half-open token indices.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is outside
            ``[0, chunk_size)``.
    """
    # Validate before the (expensive) forward pass. Without this,
    # overlap >= chunk_size gives a step <= 0 and the loop below never ends,
    # chunk_size <= 0 mean-pools an empty slice, and a negative overlap
    # silently skips tokens between windows.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size ({chunk_size}), got {overlap}"
        )
    step = chunk_size - overlap

    # 1. Contextualised token embeddings for the WHOLE document.
    token_embs, input_ids = _get_token_embeddings(text)

    # Any special tokens the tokenizer added are kept here: they contribute to
    # the pooled embeddings of the windows that contain them, while decode()
    # below skips them in the returned text.
    total_tokens = len(input_ids)
    chunks = []

    # 2. Slide a window of `chunk_size` tokens across the sequence.
    start = 0
    while start < total_tokens:
        end = min(start + chunk_size, total_tokens)

        chunk_token_ids = input_ids[start:end]
        chunk_token_embs = token_embs[start:end]

        # 3. "Late" pooling: these token embeddings were computed with the
        # whole document in context, so the pooled vector reflects text
        # outside this window too.
        late_chunk_emb = torch.mean(chunk_token_embs, dim=0)

        # Unit-normalise so dot products equal cosine similarity.
        late_chunk_emb = F.normalize(late_chunk_emb, p=2, dim=0)

        chunk_text = tokenizer.decode(chunk_token_ids, skip_special_tokens=True)

        chunks.append({
            "text": chunk_text,
            "embedding": late_chunk_emb.cpu().numpy(),
            "token_range": (start, end)
        })

        if end == total_tokens:
            break
        start += step

    return chunks

In [20]:
# Load the sample research paper. An explicit encoding keeps the read
# independent of the platform's default locale.
with open("../data/sample_research_paper.txt", "r", encoding="utf-8") as f:
    research_paper_text = f.read()

# Late-chunking parameters: 512-token windows; consecutive windows share 64 tokens.
CHUNK_SIZE = 512
CHUNK_OVERLAP = 64

chunks = chunk_text_with_embeddings(research_paper_text, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP)

# Sanity checks: something was produced, and every window starts exactly
# CHUNK_OVERLAP tokens before the previous one ended.
assert chunks, "chunking produced no chunks"
assert all(
    nxt["token_range"][0] == cur["token_range"][1] - CHUNK_OVERLAP
    for cur, nxt in zip(chunks, chunks[1:])
), "consecutive chunks should overlap by CHUNK_OVERLAP tokens"

print(f"Created {len(chunks)} chunks:\n")

# Tabulate the chunks for a quick overview
df = pd.DataFrame(chunks)
print(df[['text', 'token_range']].head(10))

# Full text of every chunk
for i, chunk in enumerate(chunks):
    print(f"Chunk {i+1}:")
    print(f"  Text: {chunk['text']}")

Created 8 chunks:

                                                text   token_range
0  Semantic Chunking Strategies for Retrieval-Aug...      (0, 512)
1  d responses. Traditional chunking approaches e...    (448, 960)
2  a framework for representing hierarchical disc...   (896, 1408)
3  unrelated sentences. Recursive Chunking: Appli...  (1344, 1856)
4  with ground-truth section boundaries, we compu...  (1792, 2304)
5  4.2 Retrieval Performance Retrieval accuracy l...  (2240, 2752)
6  gy Selection Guidelines Based on our findings,...  (2688, 3200)
7  d by application requirements, document charac...  (3136, 3295)
Chunk 1:
  Text: Semantic Chunking Strategies for Retrieval-Augmented Generation Systems Abstract Retrieval-Augmented Generation (RAG) systems have emerged as a powerful paradigm for enhancing large language models with external knowledge. A critical component of RAG systems is text chunking - the process of segmenting documents into manageable units for embedding and retriev