# COMP 631 Project Task 2
Fei Teng (ft28), Lingyi Xu (lx28)

March 23rd, 2025

In [21]:
# import necessary package for dunzhang/stella_en_400M_v5
!pip install xformers



## Building the Precomputed Index
This part loads a large corpus from a JSON file using stream parsing (via ijson) to avoid excessive memory usage. It uses the StellaEmbedding class to encode document texts (concatenated title and text) into fixed-length embeddings. The documents are processed in small chunks (e.g., 1000 documents per chunk) and the resulting embeddings, along with their document IDs, are aggregated and saved to disk as a precomputed index. This precomputation step enables fast future retrieval without re-encoding the entire corpus.

In [None]:
import json
import torch
import gc
import os
import ijson
from transformers import AutoTokenizer, AutoModel

class StellaEmbedding:
    """Sentence embedder built on a Hugging Face encoder with masked mean pooling.

    The tokenizer and model are loaded with ``AutoTokenizer`` / ``AutoModel``
    (``trust_remote_code=True``). :meth:`encode_batch` averages the final
    hidden states over the non-padding tokens of every input text.

    NOTE(review): this is plain mean pooling of ``last_hidden_state``; check the
    checkpoint's model card for any recommended query prompt or projection
    head -- TODO confirm.
    """

    def __init__(self, model_name="dunzhang/stella_en_400M_v5", device="cuda", max_length=256, use_half=False):
        """Load tokenizer and model, move the model to ``device`` and set eval mode.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``.
            device: Torch device for the model and the tokenized inputs. A CUDA
                request falls back to ``"cpu"`` (with a warning) when CUDA is
                not available, instead of failing in ``model.to``.
            max_length: Truncation length (in tokens) used by ``encode_batch``.
            use_half: If True, cast the model weights with ``model.half()``.
        """
        import warnings  # local import keeps this class self-contained in its cell

        if str(device).startswith("cuda") and not torch.cuda.is_available():
            warnings.warn("CUDA is not available; StellaEmbedding is falling back to CPU.")
            device = "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        self.device = device
        self.model.to(self.device)
        if use_half:
            self.model.half()
        self.model.eval()
        self.max_length = max_length

    @torch.no_grad()
    def encode_batch(self, texts, batch_size=8):
        """Embed ``texts`` and return one mean-pooled vector per text.

        Args:
            texts: Sequence of strings to encode.
            batch_size: Number of texts sent through the model per forward pass.

        Returns:
            A CPU tensor with one row per input text. For an empty ``texts``
            an empty tensor with zero rows is returned (the column count is
            ``model.config.hidden_size`` when that attribute exists, else 0).
        """
        if len(texts) == 0:
            # torch.cat rejects an empty list, so answer the empty case directly.
            config = getattr(self.model, "config", None)
            hidden_dim = getattr(config, "hidden_size", 0) or 0
            return torch.empty((0, hidden_dim))

        all_embs = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            encoded = self.tokenizer(batch, padding=True, truncation=True,
                                       max_length=self.max_length, return_tensors='pt')
            # Move every tokenizer output tensor to the model's device.
            encoded = {k: v.to(self.device) for k, v in encoded.items()}
            outputs = self.model(**encoded)
            token_embeddings = outputs.last_hidden_state  # [B, seq_len, hidden_dim]
            attention_mask = encoded['attention_mask'].unsqueeze(-1).float()  # [B, seq_len, 1]
            # Masked mean: padding positions contribute zero to the sum and to the count.
            sum_emb = torch.sum(token_embeddings * attention_mask, dim=1)
            # Clamp guards against division by zero for an all-padding row.
            denom = torch.clamp(attention_mask.sum(dim=1), min=1e-9)
            embs = sum_emb / denom
            all_embs.append(embs.cpu())
        return torch.cat(all_embs, dim=0)

def stream_corpus(json_path, chunk_size=1000):
    """
    Stream-parse a corpus JSON file with ijson and yield it in chunks.

    Expected JSON structure: { doc_id1: { "title": ..., "text": ... }, doc_id2: { ... }, ... }

    Args:
        json_path: Path to the corpus JSON file.
        chunk_size: Maximum number of documents per yielded chunk.

    Yields:
        Lists of (doc_id, doc) tuples, each holding up to ``chunk_size`` entries;
        the final list may be shorter.
    """
    batch = []
    # ijson parses bytes; binary mode avoids its deprecated (and slower) text-file path.
    with open(json_path, 'rb') as f:
        # The empty prefix selects the key/value pairs of the top-level object.
        for doc_id, doc in ijson.kvitems(f, ''):
            batch.append((doc_id, doc))
            if len(batch) >= chunk_size:
                yield batch
                batch = []
        # Flush the trailing partial chunk, if any.
        if batch:
            yield batch

def build_index(corpus_path, index_path="precomputed_index.pt", chunk_size=1000, batch_size=8, embedder=None):
    """Encode every document of the corpus and save the embeddings as an index file.

    Each document is embedded from its title and text joined by a single space.
    The result is written with ``torch.save`` as a dict with the keys
    ``"doc_ids"`` (list) and ``"embeddings"`` (tensor of shape [N, hidden_dim]).

    Args:
        corpus_path: Path to the corpus JSON file read by ``stream_corpus``.
        index_path: Destination file for the saved index.
        chunk_size: Number of documents read from the corpus per chunk.
        batch_size: Number of texts per model forward pass.
        embedder: Optional object with an ``encode_batch(texts, batch_size=...)``
            method. When omitted, a ``StellaEmbedding`` is created on CUDA.

    Raises:
        RuntimeError: If the corpus contains no documents.
    """
    if embedder is None:
        embedder = StellaEmbedding(model_name="dunzhang/stella_en_400M_v5", device="cuda", max_length=256, use_half=False)
    all_doc_ids = []
    embeddings_list = []
    total_docs = 0

    for batch in stream_corpus(corpus_path, chunk_size=chunk_size):
        doc_ids = [doc_id for doc_id, _ in batch]
        # `or ""` also covers fields that are present but null in the JSON.
        texts = [(doc.get("title") or "") + " " + (doc.get("text") or "") for _, doc in batch]

        batch_embs = embedder.encode_batch(texts, batch_size=batch_size)
        embeddings_list.append(batch_embs)
        all_doc_ids.extend(doc_ids)
        total_docs += len(batch)
        print(f"Processed {total_docs} documents")
        # Drop per-chunk temporaries before the next chunk is parsed.
        del doc_ids, texts, batch_embs
        gc.collect()

    if not embeddings_list:
        # torch.cat would fail on an empty list with a less helpful message.
        raise RuntimeError(f"No documents found in {corpus_path}; index was not written")

    # Merge vectors of all blocks
    all_embeddings = torch.cat(embeddings_list, dim=0)  # shape: [N, hidden_dim]
    print(f"Total embeddings shape: {all_embeddings.shape}")
    # Saving precomputed indexes to disk
    torch.save({"doc_ids": all_doc_ids, "embeddings": all_embeddings}, index_path)
    print(f"Precomputed index saved to {index_path}")

# Entry point of the index-building cell: encode the whole corpus once and save it.
if __name__ == "__main__":
    # Corpus JSON under /content/drive, the mount point used by the drive.mount cell
    # of this notebook, so Drive has to be mounted before this cell runs.
    corpus_path = "/content/drive/MyDrive/corpus_dict.json"
    # Relative path: the index is written to the current working directory.
    output_index_path = "precomputed_index.pt"
    # Read 1000 documents per chunk and encode them 8 texts per forward pass.
    build_index(corpus_path, index_path=output_index_path, chunk_size=1000, batch_size=8)


Some weights of the model checkpoint at dunzhang/stella_en_400M_v5 were not used when initializing NewModel: {'new.pooler.dense.weight', 'new.pooler.dense.bias'}
- This IS expected if you are initializing NewModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing NewModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Processed 1000 documents
Processed 2000 documents
Processed 3000 documents
Processed 4000 documents
Processed 5000 documents
Processed 6000 documents
Processed 7000 documents
Processed 8000 documents
Processed 9000 documents
Processed 10000 documents
Processed 11000 documents
Processed 12000 documents
Processed 13000 documents
Processed 14000 documents
Processed 15000 documents
Processed 16000 documents
Processed 17000 documents
Processed 18000 documents
Processed 19000 documents
Processed 20000 documents
Processed 21000 documents
Processed 22000 documents
Processed 23000 documents
Processed 24000 documents
Processed 25000 documents
Processed 26000 documents
Processed 27000 documents
Processed 28000 documents
Processed 29000 documents
Processed 30000 documents
Processed 31000 documents
Processed 32000 documents
Processed 33000 documents
Processed 34000 documents
Processed 35000 documents
Processed 36000 documents
Processed 37000 documents
Processed 38000 documents
Processed 39000 docum

<!-- Retrieval based on the precomputed index: search_precomputed_index.py -->
## Searching the Precomputed Index

This part loads the precomputed index file, which contains the document IDs and their corresponding embeddings. For a given query string, it uses the same StellaEmbedding class to encode the query into an embedding. Then, it computes the cosine similarity between the query embedding and all document embeddings, and selects the top‑k most similar documents using torch.topk. The search function returns a list of dictionaries containing each document’s ID and its similarity score.

In [6]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

class StellaEmbedding:
    """Sentence embedder built on a Hugging Face encoder with masked mean pooling.

    The tokenizer and model are loaded with ``AutoTokenizer`` / ``AutoModel``
    (``trust_remote_code=True``). :meth:`encode_batch` averages the final
    hidden states over the non-padding tokens of every input text.

    NOTE(review): this is plain mean pooling of ``last_hidden_state``; check the
    checkpoint's model card for any recommended query prompt or projection
    head -- TODO confirm.
    """

    def __init__(self, model_name="dunzhang/stella_en_400M_v5", device="cuda", max_length=256, use_half=False):
        """Load tokenizer and model, move the model to ``device`` and set eval mode.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``.
            device: Torch device for the model and the tokenized inputs. A CUDA
                request falls back to ``"cpu"`` (with a warning) when CUDA is
                not available, instead of failing in ``model.to``.
            max_length: Truncation length (in tokens) used by ``encode_batch``.
            use_half: If True, cast the model weights with ``model.half()``.
        """
        import warnings  # local import keeps this class self-contained in its cell

        if str(device).startswith("cuda") and not torch.cuda.is_available():
            warnings.warn("CUDA is not available; StellaEmbedding is falling back to CPU.")
            device = "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        self.device = device
        self.model.to(self.device)
        if use_half:
            self.model.half()
        self.model.eval()
        self.max_length = max_length

    @torch.no_grad()
    def encode_batch(self, texts, batch_size=8):
        """Embed ``texts`` and return one mean-pooled vector per text.

        Args:
            texts: Sequence of strings to encode.
            batch_size: Number of texts sent through the model per forward pass.

        Returns:
            A CPU tensor with one row per input text. For an empty ``texts``
            an empty tensor with zero rows is returned (the column count is
            ``model.config.hidden_size`` when that attribute exists, else 0).
        """
        if len(texts) == 0:
            # torch.cat rejects an empty list, so answer the empty case directly.
            config = getattr(self.model, "config", None)
            hidden_dim = getattr(config, "hidden_size", 0) or 0
            return torch.empty((0, hidden_dim))

        all_embs = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            encoded = self.tokenizer(batch, padding=True, truncation=True,
                                       max_length=self.max_length, return_tensors='pt')
            # Move every tokenizer output tensor to the model's device.
            encoded = {k: v.to(self.device) for k, v in encoded.items()}
            outputs = self.model(**encoded)
            token_embeddings = outputs.last_hidden_state  # [B, seq_len, hidden_dim]
            attention_mask = encoded['attention_mask'].unsqueeze(-1).float()  # [B, seq_len, 1]
            # Masked mean: padding positions contribute zero to the sum and to the count.
            sum_emb = torch.sum(token_embeddings * attention_mask, dim=1)
            # Clamp guards against division by zero for an all-padding row.
            denom = torch.clamp(attention_mask.sum(dim=1), min=1e-9)
            embs = sum_emb / denom
            all_embs.append(embs.cpu())
        return torch.cat(all_embs, dim=0)

def cosine_scores(query_emb, doc_embs):
    """Return the cosine similarity between one query vector and every row of ``doc_embs``.

    A leading axis is added to ``query_emb`` so it broadcasts against the
    document matrix; the similarity is taken along the last dimension.
    """
    return F.cosine_similarity(query_emb[None], doc_embs, dim=-1)

# Reused between calls so repeated queries do not reload the model or the index file.
_SEARCH_CACHE = {}


def _load_index(index_path):
    """Return the saved index dict, reloading it only when the file has changed."""
    import os  # local import: this cell's import list does not include os

    stamp = (os.path.abspath(index_path), os.path.getmtime(index_path))
    if _SEARCH_CACHE.get("index_stamp") != stamp:
        # NOTE(security): torch.load unpickles the file; only open index files you trust.
        _SEARCH_CACHE["index"] = torch.load(index_path, map_location="cpu")
        _SEARCH_CACHE["index_stamp"] = stamp
    return _SEARCH_CACHE["index"]


def _default_embedder():
    """Return the shared StellaEmbedding, creating it on first use."""
    if "embedder" not in _SEARCH_CACHE:
        _SEARCH_CACHE["embedder"] = StellaEmbedding(
            model_name="dunzhang/stella_en_400M_v5", device="cuda", max_length=256, use_half=False
        )
    return _SEARCH_CACHE["embedder"]


def search_precomputed_index(index_path, query_text, top_k=5, batch_size=8, embedder=None):
    """Return the ``top_k`` documents most similar to ``query_text``.

    Args:
        index_path: File written by ``torch.save`` holding a dict with the keys
            ``"doc_ids"`` and ``"embeddings"`` (shape [N, hidden_dim]).
        query_text: Query string to encode.
        top_k: Number of results wanted; clamped to the number of indexed
            documents, so asking for more than exist returns all of them.
        batch_size: Batch size forwarded to the embedder.
        embedder: Optional object with an ``encode_batch(texts, batch_size=...)``
            method. When omitted, one shared ``StellaEmbedding`` is created on
            first use and reused by later calls.

    Returns:
        A list of ``{"doc_id": ..., "similarity": ...}`` dicts ordered from the
        highest to the lowest cosine similarity. Empty when ``top_k`` is not
        positive or the index holds no documents.
    """
    # Load pre-calculated index (cached across calls while the file is unchanged)
    index_data = _load_index(index_path)
    doc_ids = index_data["doc_ids"]
    embeddings = index_data["embeddings"]  # shape: [N, hidden_dim]

    # torch.topk raises when k exceeds the number of scores, so clamp it first.
    k = min(top_k, len(doc_ids))
    if k <= 0:
        return []

    # Encode the query
    if embedder is None:
        embedder = _default_embedder()
    query_emb = embedder.encode_batch([query_text], batch_size=batch_size)[0]

    # Calculate the cosine similarity of all documents to the query
    scores = cosine_scores(query_emb, embeddings)

    # Get the k best results, best first
    top_scores, top_indices = torch.topk(scores, k, largest=True, sorted=True)
    return [
        {"doc_id": doc_ids[idx], "similarity": score}
        for score, idx in zip(top_scores.tolist(), top_indices.tolist())
    ]

# Demo of the search cell: run one sample query against the saved index and print the hits.
if __name__ == "__main__":
    # Index file name, resolved relative to the current working directory.
    index_path = "precomputed_index.pt"
    # sample question
    query = "What is a traditional breakfast in Mexico City?"
    top_k = 5
    results = search_precomputed_index(index_path, query, top_k=top_k, batch_size=8)

    print("\nQ: " + query)
    print("Top-K results from precomputed index:")
    # Ranks are 1-based; each result carries the document ID and its cosine similarity.
    for rank, res in enumerate(results, start=1):
        print(f"{rank}. DocID={res['doc_id']}, similarity={res['similarity']:.4f}")


Some weights of the model checkpoint at dunzhang/stella_en_400M_v5 were not used when initializing NewModel: {'new.pooler.dense.weight', 'new.pooler.dense.bias'}
- This IS expected if you are initializing NewModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing NewModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).



Q: What is a traditional breakfast in Mexico City?
Top-K results from precomputed index:
1. DocID=13444, similarity=0.6953
2. DocID=35507, similarity=0.6940
3. DocID=11626, similarity=0.6932
4. DocID=35545, similarity=0.6927
5. DocID=35517, similarity=0.6927


## Loading Document Previews
This part loads a CSV file containing document preview information (document ID, title, and the first sentence) into a Python dictionary. This allows the retrieval system to later display additional information (such as the document title and a brief preview) alongside the search results. This approach avoids reloading the full corpus and keeps the interactive session lightweight.

In [10]:
import csv
index_path = "precomputed_index.pt"
preview_csv = "ids_titles_sentences.csv"

# Load CSV content into a dictionary keyed by document ID:
#   doc_previews[id] -> {'title': ..., 'first_sentence': ...}
doc_previews = {}
# newline='' lets the csv module handle line endings itself, so quoted fields
# that contain line breaks are parsed as a single value.
with open(preview_csv, 'r', encoding='utf-8', newline='') as f:
    reader = csv.DictReader(f)
    for row in reader:
        doc_previews[row['id']] = {
            'title': row['title'],
            'first_sentence': row['first_sentence']
        }

# # Retrieve and display results
# while True:
#     query = "What is a traditional breakfast in Mexico City?"
#     top_k = 5
#     results = search_precomputed_index(index_path, query, top_k=top_k, batch_size=8)

#     print("\nQ:", query)
#     print("Top-K results from precomputed index:")
#     for rank, res in enumerate(results, start=1):
#         doc_id = res['doc_id']
#         similarity = res['similarity']
#         title = doc_previews.get(doc_id, {}).get('title', 'N/A')
#         first_sentence = doc_previews.get(doc_id, {}).get('first_sentence', 'N/A')
#         print(f"{rank}. Similarity={similarity:.4f}, DocID={doc_id}, Name={title}, Preview={first_sentence}")
#     break  # Add break here to prevent infinite loops

In [19]:
# Filter out warnings about unused weights during model initialization
# by setting the transformers log verbosity to errors.
# NOTE: after this cell the name `logging` refers to transformers' logging helper,
# not to the standard-library `logging` module.
from transformers import logging
logging.set_verbosity_error()

## Interactive Search Session
This interactive loop allows a user to repeatedly input queries and see the top‑k search results from the precomputed index. For each query, the system encodes the query, retrieves the most similar documents, and then displays for each result the document ID, similarity score, title, and a preview (first sentence). The session supports an exit command (typing "exit", "quit", or an empty input) so the user can terminate the interactive demo gracefully. All of these steps use the precomputed index and the document preview dictionary, so the corpus never has to be re-encoded; only the query itself is passed through the model at search time.

In [26]:
# Read queries until the user types "exit"/"quit" or submits an empty line,
# printing the top hits together with the stored title and first sentence.
print("Interactive Search Session (type 'exit' or leave empty to quit)")
while True:
    query = input("Q: ").strip()
    if query.lower() in ("exit", "quit", ""):
        print("Exiting interactive session.")
        break

    top_k = 5
    results = search_precomputed_index(index_path, query, top_k=top_k, batch_size=8)

    print("\nTop-K results from precomputed index:")
    for rank, res in enumerate(results, start=1):
        doc_id, similarity = res['doc_id'], res['similarity']
        # Documents without a preview row fall back to 'N/A' for both fields.
        preview = doc_previews.get(doc_id, {})
        title = preview.get('title', 'N/A')
        first_sentence = preview.get('first_sentence', 'N/A')
        print(f"{rank}. Similarity={similarity:.4f}, DocID={doc_id}, Name={title}, Preview={first_sentence}")
    print("\n" + "-"*80)

Interactive Search Session (type 'exit' or leave empty to quit)
Q: What are the top tourist attractions in Rome?

Top-K results from precomputed index:
1. Similarity=0.7515, DocID=2981, Name=Attractions, Preview=# AttractionsFor information on attractions at specific locations check out articles of nearby places or destinations further afield
2. Similarity=0.7506, DocID=1781, Name=Ancient Rome tour, Preview=# Ancient Rome tourThe Forum Romanum tour is a walking tour around the Colosseo district in Rome, which was the centre of the Roman Empire
3. Similarity=0.7466, DocID=46425, Name=ROM, Preview=# RomeFor other places with the same name, see Rome (disambiguation)
4. Similarity=0.7444, DocID=46962, Name=Records, Preview=# AttractionsFor information on attractions at specific locations check out articles of nearby places or destinations further afield
5. Similarity=0.7437, DocID=48010, Name=Rome, Preview=# RomeFor other places with the same name, see Rome (disambiguation)

--------------

In [1]:
# Mount Google Drive at /content/drive (requires the Google Colab runtime).
# The index-building cell reads its corpus from /content/drive/MyDrive/, so this
# cell has to be executed before that one.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive
