In [1]:
import json
import numpy as np
import sys
import os
from tqdm import tqdm
from pyserini.search.lucene import LuceneSearcher
from pyserini.search.faiss import FaissSearcher
import pytrec_eval
from pyserini.encode import TctColBertQueryEncoder
# from feature_extractor import FeatureExtractor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Configuration ---
OUTPUT_FILE = "../MABhybrid/data/bandit_data_train.jsonl"
QUERIES_FILE = "../MABhybrid/data/msmarco-train-queries.tsv"
QRELS_FILE = "../MABhybrid/data/qrels.train.tsv"
SAMPLE_SIZE = 50000  # Number of queries to process
TOP_K = 1000         # Depth of initial retrieval
ARMS = [0.0, 0.25, 0.5, 0.75, 1.0] # Alpha values: 0.0=Dense, 1.0=Sparse

In [3]:
def load_queries(path, limit=None):
    """Loads queries from a TSV file (qid \t text)."""
    queries = []
    print(f"Loading queries from {path}...")
    with open(path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if limit and i >= limit: break
            qid, text = line.strip().split('\t')
            queries.append((qid, text))
    return queries

def load_qrels(path):
    """Loads qrels into a dictionary: {qid: {docid: relevance}}."""
    qrels = {}
    print(f"Loading qrels from {path}...")
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            qid, _, docid, rel = line.strip().split('\t')
            if qid not in qrels: qrels[qid] = {}
            qrels[qid][docid] = int(rel)
    return qrels

def normalize_scores(hits):
    """Min-Max normalization of scores to [0, 1]."""
    if not hits: return {}
    scores = [h.score for h in hits]
    min_s = min(scores)
    max_s = max(scores)
    
    # Avoid division by zero if all scores are identical
    if max_s == min_s: 
        return {h.docid: 1.0 for h in hits}
        
    # Return dictionary {docid: normalized_score}
    norm_scores = {}
    for h in hits:
        norm_scores[h.docid] = (h.score - min_s) / (max_s - min_s)
    return norm_scores

# def calculate_ndcg(run_dict, qrels_dict, k=10):
#     """
#     Calculates NDCG@k using pytrec_eval.
#     run_dict: {docid: score}
#     qrels_dict: {docid: relevance}
#     """
#     if not qrels_dict: return 0.0
    
#     # pytrec_eval expects {qid: {docid: score}} structure
#     evaluator = pytrec_eval.RelevanceEvaluator({'q': qrels_dict}, {'ndcg_cut'})
#     results = evaluator.evaluate({'q': run_dict})
#     return results['q'][f'ndcg_cut_{k}']

# python
def calculate_ndcg(qid, run_dict, qrels_dict, k=10):
    """
    qid: str query id
    run_dict: {docid: score}
    qrels_dict: {docid: relevance}
    """
    if not qrels_dict:
        return 0.0

    # Ensure native Python types (pytrec_eval C extension requires them)
    run_cast = {str(d): float(s) for d, s in run_dict.items()}
    qrels_cast = {str(d): int(r) for d, r in qrels_dict.items()}

    metric = f'ndcg_cut_{k}'
    evaluator = pytrec_eval.RelevanceEvaluator({qid: qrels_cast}, {metric})
    results = evaluator.evaluate({qid: run_cast})
    return results[qid][metric]

In [5]:
import numpy as np
from pyserini.index.lucene import LuceneIndexReader

class FeatureExtractor:
    """
    Converts raw text queries into dense numerical vectors using Pyserini 
    for corpus statistics.
    """
    def __init__(self, index_path='msmarco-passage'):
        """
        Args:
            index_path (str): The name of the pre-built index (e.g., 'msmarco-passage')
                              or a path to a local Lucene index.
        """
        print(f"Initializing FeatureExtractor with index: {index_path}")
        # We must use IndexReader to access term statistics (df, cf), not LuceneSearcher
        self.reader = LuceneIndexReader.from_prebuilt_index(index_path)
        if not self.reader:
             # Fallback for local paths if it's not a prebuilt index name
            self.reader = LuceneIndexReader(index_path)
            
        # Get total number of documents (N) for IDF calculation
        # .stats() returns a dict like {'documents': 8841823, 'non_empty_documents': ...}
        self.N = self.reader.stats()['documents']

    def get_idf(self, term):
        """
        Calculates Inverse Document Frequency (IDF) for a single term.
        Formula: log( N / (df + 1) )
        """
        # get_term_counts returns (df, cf). We only need df (Document Frequency).
        # analyzer=None uses the default analyzer for the index (recommended).
        try:
            df, _ = self.reader.get_term_counts(term, analyzer=None)
        except:
            # Handle cases where term might cause Java encoding errors or not exist
            df = 0
            
        # Avoid division by zero or log(0) issues
        # Add 1 to df for smoothing
        return np.log(self.N / (df + 1)) if (df + 1) > 0 else 0.0

    def extract(self, query_text):
        """
        Extracts the 5-dimensional feature vector for a query.
        
        Feature Definition:
        1. Length: Number of tokens
        2. Max IDF: Rarity of the rarest word (keyword specificity)
        3. Avg IDF: Average rarity (information density)
        4. Question Flag: 1.0 if starts with Wh-word, else 0.0
        5. Bias: Constant 1.0 (Intercept)
        """
        # Simple whitespace tokenization 
        # (For production, consider matching the Pyserini analyzer's tokenization)
        tokens = query_text.lower().split()
        length = len(tokens)
        
        if length == 0:
            return np.array([0.0, 0.0, 0.0, 0.0, 1.0])
        
        # Compute IDFs
        idfs = [self.get_idf(t) for t in tokens]
        
        max_idf = max(idfs) if idfs else 0.0
        avg_idf = np.mean(idfs) if idfs else 0.0
        
        # Heuristic: Check for Wh-words to detect natural language questions
        question_starters = {'who', 'what', 'where', 'when', 'why', 'how', 'which'}
        is_question = 1.0 if tokens[0] in question_starters else 0.0
        
        # Construct and return the vector
        # [Length, MaxIDF, AvgIDF, QuestionFlag, Bias]
        return np.array([float(length), max_idf, avg_idf, is_question, 1.0])

In [11]:
# 1. Setup Pyserini Searchers
print("Initializing Sparse Searcher (BM25)...")
# Automatically downloads the pre-built index
sparse_searcher = LuceneSearcher.from_prebuilt_index('msmarco-passage')


Initializing Sparse Searcher (BM25)...


In [12]:
sparse_searcher.search("test")  # Warm-up call

[<io.anserini.search.ScoredDoc at 0x3257cc5a0 jclass=io/anserini/search/ScoredDoc jself=<LocalRef obj=0x30a7610e2 at 0x400a10a30>>,
 <io.anserini.search.ScoredDoc at 0x3257cc780 jclass=io/anserini/search/ScoredDoc jself=<LocalRef obj=0x30a7610ea at 0x400a10b50>>,
 <io.anserini.search.ScoredDoc at 0x3257cc8c0 jclass=io/anserini/search/ScoredDoc jself=<LocalRef obj=0x30a7610f2 at 0x400a10990>>,
 <io.anserini.search.ScoredDoc at 0x3257cc910 jclass=io/anserini/search/ScoredDoc jself=<LocalRef obj=0x30a7610fa at 0x400a108d0>>,
 <io.anserini.search.ScoredDoc at 0x3257cc960 jclass=io/anserini/search/ScoredDoc jself=<LocalRef obj=0x30a761102 at 0x400a10870>>,
 <io.anserini.search.ScoredDoc at 0x3257cc9b0 jclass=io/anserini/search/ScoredDoc jself=<LocalRef obj=0x30a76110a at 0x400a109f0>>,
 <io.anserini.search.ScoredDoc at 0x3257cca00 jclass=io/anserini/search/ScoredDoc jself=<LocalRef obj=0x30a76111a at 0x400a10af0>>,
 <io.anserini.search.ScoredDoc at 0x3257cd1d0 jclass=io/anserini/search/Scor

In [14]:

print("Initializing Dense Searcher (TCT-ColBERT)...")
# Automatically downloads encoder and index
encoder = TctColBertQueryEncoder('castorini/tct_colbert-msmarco')
dense_searcher = FaissSearcher.from_prebuilt_index(
    'msmarco-passage-tct_colbert-hnsw', 
    encoder
)
print("Dense Searcher initialized.")


Initializing Dense Searcher (TCT-ColBERT)...
Attempting to initialize prebuilt index msmarco-passage-tct_colbert-hnsw.
/Users/ativsc/.cache/pyserini/indexes/faiss-hnsw.msmarco-v1-passage.tct_colbert.20210112.be7119.6b7285a7f0163d1a547214396be20488 already exists, skipping download.
Initializing msmarco-v1-passage.tct_colbert.hnsw...
Dense Searcher initialized.


In [15]:
# 2. Setup Feature Extractor
# Reuse the sparse searcher's index reader for IDF stats
feature_extractor = FeatureExtractor('msmarco-passage')

# 3. Load Data
all_queries = load_queries(QUERIES_FILE, limit=None) # Load all first, then sample
all_qrels = load_qrels(QRELS_FILE)

# Randomly sample queries that actually have qrels
valid_queries = [q for q in all_queries if q[0] in all_qrels]
print(f"Found {len(valid_queries)} queries with judgments.")

# Shuffle and take 50k
np.random.seed(42)
indices = np.random.choice(len(valid_queries), min(SAMPLE_SIZE, len(valid_queries)), replace=False)
sampled_queries = [valid_queries[i] for i in indices]

Initializing FeatureExtractor with index: msmarco-passage
Loading queries from ../MABhybrid/data/msmarco-train-queries.tsv...
Loading qrels from ../MABhybrid/data/qrels.train.tsv...
Found 502939 queries with judgments.


In [7]:
# run in your notebook / REPL
print(type(feature_extractor.reader))
print(dir(feature_extractor.reader))
print([m for m in dir(feature_extractor.reader) if 'term' in m.lower() or 'doc' in m.lower()])

<class 'pyserini.index.lucene._base.LuceneIndexReader'>
['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'analyze', 'compute_bm25_term_weight', 'compute_query_document_score', 'convert_collection_docid_to_internal_docid', 'convert_internal_docid_to_collection_docid', 'doc', 'doc_by_field', 'doc_contents', 'doc_raw', 'dump_documents_BM25', 'from_prebuilt_index', 'get_document_vector', 'get_postings_list', 'get_term_counts', 'get_term_positions', 'list_prebuilt_indexes', 'object', 'quantize_weights', 'reader', 'stats', 'terms', 'validate']
['__doc__', 'compute_bm25_term_weight', 'compute_query_document_score', 'convert_collection_docid_to_internal_docid', 'convert_internal_docid_to

In [16]:
print(f"Starting processing for {len(sampled_queries)} queries...")

with open(OUTPUT_FILE, 'w') as f_out:
    for qid, text in tqdm(sampled_queries):
        # A. Retrieve Sparse & Dense
        # print(text)
        # try:
        sparse_hits = sparse_searcher.search(text, k=TOP_K)
        dense_hits = dense_searcher.search(text, k=TOP_K)
        # except Exception as e:
        #     # Handle empty queries or encoding errors
        #     print(f"Error retrieving for query {qid}: {e}", file=sys.stderr)
        #     continue

        # B. Normalize Scores
        # Map: {docid: norm_score}
        sparse_dict = normalize_scores(sparse_hits)
        dense_dict = normalize_scores(dense_hits)
        
        # Union of all retrieved docs
        all_docs = set(sparse_dict.keys()) | set(dense_dict.keys())
        
        query_rewards = []
        
        # C. Fusion Loop (Calculate reward for each Arm)
        for alpha in ARMS:
            fused_scores = {}
            for docid in all_docs:
                s_score = sparse_dict.get(docid, 0.0)
                d_score = dense_dict.get(docid, 0.0)
                # Fusion Formula
                score = alpha * s_score + (1.0 - alpha) * d_score
                fused_scores[docid] = score
            
            # Sort top K for evaluation
            # Note: pytrec_eval handles sorting, but we can limit size for speed
            top_100 = dict(sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)[:100])
            
            # Calculate Reward (NDCG@10)
            # reward = calculate_ndcg(top_100, all_qrels[qid], k=10)
            reward = calculate_ndcg(qid, top_100, all_qrels[qid], k=10)
            query_rewards.append(reward)

        # D. Feature Extraction
        features = feature_extractor.extract(text).tolist()

        # Save
        record = {
            "query_id": qid,
            "text": text,
            "features": features,
            "rewards": query_rewards,
            "optimal_arm": int(np.argmax(query_rewards))
        }
        f_out.write(json.dumps(record) + '\n')

print(f"Done! Generated data saved to {OUTPUT_FILE}")


Starting processing for 50000 queries...


  0%|          | 100/50000 [10:02<83:30:51,  6.03s/it]


KeyboardInterrupt: 