In [13]:
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from typing import List, Dict, Optional
from pydantic import BaseModel, Field, ConfigDict

In [14]:
class Chunk(BaseModel):
    """Represents a text chunk with its embedding."""
    model_config = ConfigDict(arbitrary_types_allowed=True)
    
    text: str
    start_index: int
    end_index: int
    token_count: int
    embedding: np.ndarray

In [16]:
# Set device
device = None
# model_name = "sentence-transformers/all-MiniLM-L6-v2"
# model_name="BAAI/bge-small-en-v1.5"
model_name = "BAAI/bge-m3"

if device is None:
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
else:
    device = torch.device(device)

print(f"Using device: {device}")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)


Using device: mps


tokenizer_config.json:   0%|          | 0.00/444 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/964 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/687 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.27G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.27G [00:00<?, ?B/s]

In [17]:
def _get_token_embeddings(text: str) -> tuple[np.ndarray, List[int]]:
    """
    Get token-level embeddings for the entire text.
    """
    # Tokenize the text
    tokens = tokenizer(
        text, 
        return_tensors='pt', 
        truncation=False,
        padding=False,
    )
    
    input_ids = tokens['input_ids'].to(device)
    attention_mask = tokens['attention_mask'].to(device)

    # Get embeddings
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # Use last hidden state for token embeddings
        token_embeddings = outputs.last_hidden_state[0]  # Remove batch dimension

    # Convert to numpy
    token_embeddings = token_embeddings.cpu().numpy()
    token_ids = input_ids[0].cpu().tolist()

    print("Token embeddings shape:", token_embeddings.shape)
    print("Token IDs:", token_ids)
    print("Token embeddings:", token_embeddings)
    
    return token_embeddings, token_ids

In [18]:
def _create_chunks_by_tokens(text: str, token_ids: List[int], chunk_size: int, overlap: int) -> List[Dict]:

    """
        Create chunks based on token count with overlap.
    """
    chunks = []
    num_tokens = len(token_ids)

    # Skip special tokens (CLS, SEP, etc.)
    start_offset = 1 if token_ids[0] == tokenizer.cls_token_id else 0
    end_offset = 1 if token_ids[-1] == tokenizer.sep_token_id else 0
    
    start_token = start_offset

    while start_token < num_tokens - end_offset:
        end_token = min(start_token + chunk_size, num_tokens - end_offset)
        
        # Decode tokens to get text
        chunk_token_ids = token_ids[start_token:end_token]
        chunk_text = tokenizer.decode(chunk_token_ids, skip_special_tokens=True)
        
        # Find character positions (approximate)
        decoded_so_far = tokenizer.decode(
            token_ids[start_offset:start_token],
            skip_special_tokens=True
        )
        char_start = len(decoded_so_far)
        char_end = char_start + len(chunk_text)
        
        chunks.append({
            'text': chunk_text,
            'start_char': char_start,
            'end_char': char_end,
            'start_token': start_token,
            'end_token': end_token,
            'token_count': end_token - start_token
        })
        
        # Move to next chunk with overlap
        start_token = end_token - overlap
        
        # Break if we've covered all tokens
        if end_token >= num_tokens - end_offset:
            break
    
    return chunks



In [19]:
def _pool_chunk_embeddings(
    token_embeddings: np.ndarray,
    start_token: int,
    end_token: int
) -> np.ndarray:
    
    """Pool token embeddings for a chunk using mean pooling."""
    
    chunk_embeddings = token_embeddings[start_token:end_token]

    return np.mean(chunk_embeddings, axis=0)

In [20]:
def chunk_text(text:str, chunk_size:int, overlap:int) -> List[Chunk]:
    
    if not text.strip():
        return []

    # Step 1: Get token-level embeddings for entire text
    token_embeddings, token_ids = _get_token_embeddings(text)

    #Step 2: Create chunks based on token boundaries
    chunk_info = _create_chunks_by_tokens(text, token_ids, chunk_size, overlap)


    # Step 3: Pool token embeddings for each chunk (Mean Pooling)
    chunks = []
    for info in chunk_info:
        # Average the token embeddings within this chunk
        chunk_embedding = _pool_chunk_embeddings(
            token_embeddings,
            info['start_token'],
            info['end_token']
        )
        
        chunks.append(Chunk(
            text=info['text'],
            start_index=info['start_char'],
            end_index=info['end_char'],
            token_count=info['token_count'],
            embedding=chunk_embedding
        ))
    
    return chunks

In [21]:
def get_embeddings_matrix(chunks: List[Chunk]) -> np.ndarray:
    """
    Extract embeddings from chunks as a matrix.
    
    Args:
        chunks: List of Chunk objects
        
    Returns:
        numpy array of shape (num_chunks, embedding_dim)
    """
    return np.array([chunk.embedding for chunk in chunks])

In [22]:
# Load research paper
with open("../data/sample_research_paper.txt", "r") as f:
    research_paper_text = f.read()

# Chunk the text
chunks = chunk_text(research_paper_text, chunk_size=128, overlap=20)

# Display results
print(f"Created {len(chunks)} chunks:\n")
for i, chunk in enumerate(chunks):
    print(f"Chunk {i+1}:")
    print(f"  Text: {chunk.text}")
    print(f"  Tokens: {chunk.token_count}")
    print(f"  Embedding shape: {chunk.embedding.shape}")
    print(f"  Position: [{chunk.start_index}:{chunk.end_index}]")
    print()

# Get embeddings as matrix
embeddings = get_embeddings_matrix(chunks)
print(f"Embeddings matrix shape: {embeddings.shape}")

Token embeddings shape: (3295, 1024)
Token IDs: [0, 10232, 109109, 183501, 449, 141075, 7, 100, 853, 97351, 1405, 9, 186432, 1183, 71, 83479, 84837, 233973, 853, 97351, 1405, 9, 186432, 1183, 71, 83479, 15, 12280, 724, 16, 76519, 765, 74216, 71, 237, 10, 113138, 214709, 100, 22, 1121, 21896, 21334, 46876, 115774, 678, 173591, 51359, 5, 62, 130306, 82761, 111, 19400, 724, 76519, 83, 7986, 7839, 6048, 20, 70, 9433, 111, 33180, 214, 60525, 3934, 111240, 2886, 25072, 7, 100, 6, 55720, 59725, 136, 456, 97351, 1405, 5, 3293, 15122, 13379, 7, 10, 199083, 114137, 111, 484, 109109, 7839, 6048, 50531, 7, 4, 37397, 214, 89160, 188347, 9, 62539, 51515, 90, 678, 5744, 6, 55720, 59725, 9, 77007, 150624, 5, 1401, 151575, 13, 300, 86, 117781, 7839, 6048, 234873, 7, 36880, 48716, 158208, 26719, 484, 109109, 552, 3334, 6620, 4, 456, 97351, 1405, 61689, 219, 2408, 4, 136, 181135, 43315, 227066, 5, 22929, 28007, 7, 106804, 13, 450, 39908, 5844, 484, 109109, 7839, 6048, 1810, 1264, 5037, 7, 188347, 9, 6253