In [None]:
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['CUDA_HOME'] = '/root/miniconda3/envs/colbert'
os.environ['COLBERT_LOAD_TORCH_EXTENSION_VERBOSE'] = 'True'
import csv
import re
import torch
from colbert import Indexer, Searcher
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert.data import Queries, Collection


In [None]:
def read_queries(query_file):
    """Load the query table from a tab-separated file.

    Every non-blank line must hold exactly four tab-separated fields, in this
    order: query id, raw query text, full rewrite, condensed rewrite.
    The file is parsed with ``csv.reader`` using its default quoting rules.

    Args:
        query_file: Path to a UTF-8 encoded TSV file.

    Returns:
        A list of ``(query_id, query_text, rewrite_text, condense_text)``
        tuples in file order.

    Raises:
        ValueError: If a non-blank line does not contain exactly four fields.
            The message names the file and the offending line number.
    """
    queries = []
    # newline='' leaves line-ending handling to the csv module, as its docs recommend.
    with open(query_file, 'r', encoding='utf-8', newline='') as file:
        reader = csv.reader(file, delimiter='\t')
        for row in reader:
            if not row:
                # csv yields [] for a completely blank line (e.g. a trailing
                # newline at the end of the file); there is nothing to load.
                continue
            if len(row) != 4:
                raise ValueError(
                    f"{query_file}, line {reader.line_num}: expected 4 tab-separated fields "
                    f"(query_id, query_text, rewrite_text, condense_text), got {len(row)}"
                )
            queries.append(tuple(row))
    return queries

In [None]:
def perform_search(searcher, queries, split_index, split_size, results_dir, top_k, combine_order, threshold, boost_factor):
    """Search one collection split for every query and append the hits to per-query files.

    For each query the condensed rewrite and the full rewrite (in that order)
    are handed to ``searcher.coverage_search_for_batch``; the raw query text is
    not used. The call's result is indexed as a 3-item sequence whose first
    item holds passage ids and whose third item holds scores.

    Each hit is appended to ``<results_dir>/<query_id>.txt`` as
    ``global_id<TAB>rank<TAB>score``, where ``global_id`` is
    ``split_index * split_size + passage_id`` and ``rank`` counts from 1
    within this split. Files are opened in append mode so that successive
    splits accumulate into the same file.

    Args:
        searcher: Object exposing ``coverage_search_for_batch``.
        queries: Iterable of ``(query_id, query_text, rewrite_text, condense_text)``.
        split_index: Zero-based index of the split being searched.
        split_size: Id stride between consecutive splits.
        results_dir: Existing directory that receives the per-query files.
        top_k, combine_order, threshold, boost_factor: Forwarded to the searcher.
    """
    for query_id, _raw_text, rewrite_text, condense_text in queries:
        hits = searcher.coverage_search_for_batch(
            condense_text,
            rewrite_text,
            k=top_k,
            threshold=threshold,
            boost_factor=boost_factor,
            combine_order=combine_order,
        )
        # Walk the three parallel sequences together and keep only the ids.
        local_ids = [pid for pid, _rank, _score in zip(*hits)]
        hit_scores = hits[2]
        out_path = os.path.join(results_dir, f"{query_id}.txt")
        with open(out_path, 'a', encoding='utf-8') as sink:
            sink.writelines(
                f"{split_index * split_size + local_id}\t{position}\t{hit_score}\n"
                for position, (local_id, hit_score) in enumerate(zip(local_ids, hit_scores), start=1)
            )


In [None]:
def search_and_save_intermediate_results(split_dir, checkpoint_path, experiment_name_prefix, exp_root, queries_list, results_dir_list, top_k, combine_order, threshold, boost_factor, kmeans_niters=4, nbits=2, split_size=2000000, clear_existing=True):
    """Search every collection split and write the per-query intermediate result files.

    The number of splits is the number of ``*.tsv`` files in ``split_dir``.
    For split ``n`` (1-based) the index ``{experiment_name_prefix}_split_{n}.nbits={nbits}``
    is opened over the collection ``{split_dir}/{experiment_name_prefix}_{n}.new.tsv``,
    and ``perform_search`` appends that split's hits to ``<results_dir>/<query_id>.txt``.

    Args:
        split_dir: Directory holding the split collection files.
        checkpoint_path: Checkpoint passed to ``Searcher``.
        experiment_name_prefix: Prefix used for experiment, index and collection names.
        exp_root: ``root`` passed to ``ColBERTConfig``.
        queries_list: One sequence of query tuples per query set.
        results_dir_list: One output directory per query set. A single
            directory is shared by all query sets.
        top_k, combine_order, threshold, boost_factor: Forwarded to ``perform_search``.
        kmeans_niters, nbits: Passed to ``ColBERTConfig``.
        split_size: Id stride between splits, forwarded to ``perform_search``.
        clear_existing: When true (default), per-query files left over from a
            previous run are deleted before the first split is searched.
            ``perform_search`` appends, so without this a re-run would mix
            stale rows into the new results. Pass ``False`` to keep
            accumulating into existing files.

    Raises:
        ValueError: If ``results_dir_list`` has more than one entry and its
            length differs from ``queries_list`` (``zip`` would otherwise
            silently drop the unmatched query sets).
    """
    queries_list = list(queries_list)
    results_dir_list = list(results_dir_list)
    if len(results_dir_list) == 1:
        # One directory for everything: reuse it for each query set.
        results_dir_list = results_dir_list * len(queries_list)
    elif len(results_dir_list) != len(queries_list):
        raise ValueError(
            f"got {len(queries_list)} query set(s) but {len(results_dir_list)} result "
            f"directories; pass one directory per query set or a single shared directory"
        )

    print(f"split_dir: {split_dir}")
    split_files = sorted(
        [os.path.join(split_dir, f) for f in os.listdir(split_dir) if f.endswith('.tsv')],
        # Order by the integer between the last '_' and the first '.' that follows it.
        key=lambda x: int(os.path.basename(x).split('_')[-1].split('.')[0])
    )

    for file in split_files:
        print(file)

    if clear_existing:
        # Start each run from a clean slate; results are appended split by split below.
        for queries, results_dir in zip(queries_list, results_dir_list):
            for query in queries:
                stale_file = os.path.join(results_dir, f"{query[0]}.txt")
                try:
                    os.remove(stale_file)
                except FileNotFoundError:
                    pass

    for split_index in range(len(split_files)):
        print(f"split_dir_1: {split_dir}")
        experiment_name = f"{experiment_name_prefix}_split_{split_index + 1}"
        index_name = f"{experiment_name}.nbits={nbits}"
        print(f"index_name: {index_name}")
        print(f"exp_name: {experiment_name}")
        with Run().context(RunConfig(nranks=1, experiment=experiment_name)):
            config = ColBERTConfig(nbits=nbits, root=exp_root, kmeans_niters=kmeans_niters)
            searcher = Searcher(index=index_name,
                                checkpoint=checkpoint_path,
                                collection=os.path.join(split_dir, f"{experiment_name_prefix}_{split_index + 1}.new.tsv"),
                                config=config)
            # The searcher for this split is built once and reused for every query set.
            for queries, results_dir in zip(queries_list, results_dir_list):
                perform_search(searcher, queries, split_index, split_size, results_dir, top_k, combine_order, threshold, boost_factor)

In [None]:
def merge_and_save_final_results(queries, results_dir, top_k=100):
    """Merge the per-query result files into one ranked ``final_results.txt``.

    For every query, ``<results_dir>/<query_id>.txt`` is read; each non-blank
    line is ``passage_id<TAB>rank<TAB>score``. Rows are ordered by score
    (highest first, ties keep file order), duplicate passage ids are dropped
    keeping the best-scoring row, and at most ``top_k * 3`` rows are kept.

    The merged rows are written to ``<results_dir>/final_results.txt``
    (overwriting it) as
    ``query_id<TAB>Q0<TAB>passage_id<TAB>rank<TAB>score<TAB>CR`` with ranks
    renumbered from 1 per query.

    Args:
        queries: Iterable of 4-tuples whose first item is the query id.
        results_dir: Directory holding the per-query files; also receives the output.
        top_k: Base cut-off; note that three times this many rows are kept per query.

    Raises:
        FileNotFoundError: If a query has no result file in ``results_dir``.
    """
    keep = top_k * 3
    final_results = {}
    for query_id, _, _, _ in queries:
        result_file = os.path.join(results_dir, f"{query_id}.txt")
        results = []
        with open(result_file, 'r', encoding='utf-8') as file:
            for line in file:
                stripped = line.strip()
                if not stripped:
                    # Tolerate blank lines instead of failing on a missing column.
                    continue
                fields = stripped.split('\t')
                results.append((fields[0], int(fields[1]), float(fields[2])))
        # Stable sort: equal scores stay in the order they were written.
        results.sort(key=lambda x: x[2], reverse=True)

        # A single pass is enough: every id not yet seen is appended the first
        # time it is met, so a second sweep could never add anything.
        unique_results = []
        seen_doc_ids = set()
        for result in results:
            if result[0] not in seen_doc_ids:
                unique_results.append(result)
                seen_doc_ids.add(result[0])
            if len(unique_results) == keep:
                break
        final_results[query_id] = unique_results[:keep]

    final_result_file = os.path.join(results_dir, 'final_results.txt')
    with open(final_result_file, 'w', encoding='utf-8') as file:
        for query_id, results in final_results.items():
            for rank, (global_passage_id, _, passage_score) in enumerate(results, start=1):
                file.write(f"{query_id}\tQ0\t{global_passage_id}\t{rank}\t{passage_score}\tCR\n")

In [None]:
def read_id_mapping(collection_file):
    """Build a lookup from the second column of a TSV file to its first column.

    Only rows with exactly two tab-separated fields and a non-empty second
    field are used; every other row is ignored. The code treats the first
    field as the original id and the second as the numeric id. When a numeric
    id occurs more than once, the last row wins.

    Args:
        collection_file: Path to a UTF-8 encoded, tab-separated mapping file.

    Returns:
        dict mapping numeric id (str) to original id (str).
    """
    with open(collection_file, 'r', encoding='utf-8') as handle:
        return {
            fields[1]: fields[0]
            for fields in csv.reader(handle, delimiter='\t')
            if len(fields) == 2 and fields[1]
        }

In [None]:
def map_ids_to_original(results_file, id_mapping, output_file):
    """Rewrite a ranked results file, translating passage ids through ``id_mapping``.

    Each input line is split on tabs; fields 0, 2, 3, 4 and 5 are read as
    query id, passage id, rank, score and method. The passage id is replaced
    by ``id_mapping[passage_id]`` when present and left as is otherwise. The
    second output column is always the literal ``Q0``.

    Args:
        results_file: Path of the tab-separated results file to read.
        id_mapping: dict from passage id (as it appears in the file) to the
            id that should be written instead.
        output_file: Path to write; created or overwritten.
    """
    def _translate(record):
        # One input line in, one output line out.
        fields = record.strip().split('\t')
        query_id = fields[0]
        passage_key = fields[2]
        rank = fields[3]
        score = fields[4]
        method = fields[5]
        resolved_id = id_mapping.get(passage_key, passage_key)
        return f"{query_id}\tQ0\t{resolved_id}\t{rank}\t{score}\t{method}\n"

    with open(results_file, 'r', encoding='utf-8') as source, open(output_file, 'w', encoding='utf-8') as target:
        for record in source:
            target.write(_translate(record))

In [None]:
def main(split_directory, checkpoint_path, experiment_name_prefix, exp_root, queries_dir_list, results_dir_base, default_top_k, default_threshold, default_boost_factor, default_combine_order, top_k_grid=None, threshold_grid=None, boost_factor_grid=None, combine_order_grid=None, start_index=0, merge_top_k=100):
    """Run the one-parameter-at-a-time search sweep and write merged results.

    The sweep varies one setting at a time around the defaults (top_k, then
    threshold, then boost_factor, then combine_order) and finishes with the
    all-defaults combination. Combination ``i`` (1-based) writes to the
    directory ``f"{results_dir_base}{i}"``: a ``config.txt`` describing the
    settings, the per-query intermediate files, and ``final_results.txt``.

    Args:
        split_directory: Directory with the split collection files.
        checkpoint_path: Model checkpoint handed to the searcher.
        experiment_name_prefix: Prefix of experiment / index / collection names.
        exp_root: Experiment root directory.
        queries_dir_list: Paths of query TSV files; paths that are not files
            are reported and skipped.
        results_dir_base: Prefix of the per-combination output directories.
        default_top_k, default_threshold, default_boost_factor,
        default_combine_order: Baseline settings.
        top_k_grid, threshold_grid, boost_factor_grid, combine_order_grid:
            Values to try for each setting. When left as ``None`` the
            module-level ``top_k_values`` / ``thresholds`` / ``boost_factors`` /
            ``combine_orders`` are used, which is what this function read
            before these parameters existed.
        start_index: Zero-based index of the first combination to run;
            earlier combinations are skipped but keep their directory numbers.
        merge_top_k: ``top_k`` passed to ``merge_and_save_final_results``.
    """
    queries_list = []

    for query_file in queries_dir_list:
        if os.path.isfile(query_file):
            queries = read_queries(query_file)
            queries_list.append(queries)
            print(f"read query file: {query_file}")
        else:
            print(f" {query_file} does not exist.")

    if not queries_list:
        # Without queries every index would still be loaded just to search nothing.
        print("no query file could be read; nothing to search.")
        return

    defaults = {
        'top_k': default_top_k,
        'threshold': default_threshold,
        'boost_factor': default_boost_factor,
        'combine_order': default_combine_order,
    }
    sweep_axes = [
        ('top_k', top_k_values if top_k_grid is None else top_k_grid),
        ('threshold', thresholds if threshold_grid is None else threshold_grid),
        ('boost_factor', boost_factors if boost_factor_grid is None else boost_factor_grid),
        ('combine_order', combine_orders if combine_order_grid is None else combine_order_grid),
    ]
    # One combination per non-default value on each axis, then the baseline last.
    param_combinations = [
        {**defaults, axis: value}
        for axis, values in sweep_axes
        for value in values
        if value != defaults[axis]
    ]
    param_combinations.append(dict(defaults))

    # All query sets share one directory per combination, so merge them in one
    # pass; merging set by set would overwrite final_results.txt each time.
    all_queries = [query for queries in queries_list for query in queries]

    for i, params in enumerate(param_combinations[start_index:], start=start_index):
        results_dir = f"{results_dir_base}{i + 1}"
        try:
            os.makedirs(results_dir, exist_ok=True)
        except Exception as e:
            print(f"error creating directory {results_dir}: {e}")
            # Nothing below can work without the directory; do not carry on.
            raise

        config_file = os.path.join(results_dir, 'config.txt')
        with open(config_file, 'w', encoding='utf-8') as f:
            f.write(f"top_k: {params['top_k']}\n")
            f.write(f"threshold: {params['threshold']}\n")
            f.write(f"boost_factor: {params['boost_factor']}\n")
            f.write(f"combine_order: {params['combine_order']}\n")

        # One results directory per query set: a shorter list would make the
        # zip inside the search step drop every query set after the first.
        search_and_save_intermediate_results(
            split_directory, checkpoint_path, experiment_name_prefix, exp_root,
            queries_list, [results_dir] * len(queries_list),
            params['top_k'], params['combine_order'], params['threshold'], params['boost_factor']
        )
        merge_and_save_final_results(all_queries, results_dir, top_k=merge_top_k)
        print(f"result saved: {results_dir}")


In [None]:
    # Run configuration. All paths below are placeholders.
    split_directory = "/split_collection_your_dataset" # Directory of split collection files (TSV: {id}\t{doc_text}). The IDs in each split start from 0.
    checkpoint_path = "/ColBERT/colbertv2.0"
    # Split n is searched via index "{prefix}_split_{n}.nbits=..." over collection "{prefix}_{n}.new.tsv".
    experiment_name_prefix = "your_dataset"
    exp_root = "/experiments"
    queries_dir_list = ["/query/query.tsv"] # TSV: {query_id}\t{raw_query}\t{full-rewrite query}\t{condensed-rewrite query}
    # main() appends the 1-based parameter-combination number to this prefix.
    results_dir_base = "/result/results_dir_" 
    # Not referenced by the cells shown here; presumably the input of read_id_mapping -- confirm.
    collection_file = "/your_dataset.intmapping"  # TSV: {id}\t{realid}

    # Baseline settings; main() runs the all-defaults combination last.
    default_top_k = 1000
    default_threshold = 0.5
    default_boost_factor = 0.05
    default_combine_order = 'original_first'

    # Candidate values per setting. main() reads these module-level names
    # and skips any value equal to the corresponding default.
    top_k_values = [1000]
    thresholds = [0.5]
    boost_factors = [0.05]
    combine_orders = ['original_first']

In [None]:
    # Launch the sweep with the settings from the configuration cell.
    main(
        split_directory=split_directory,
        checkpoint_path=checkpoint_path,
        experiment_name_prefix=experiment_name_prefix,
        exp_root=exp_root,
        queries_dir_list=queries_dir_list,
        results_dir_base=results_dir_base,
        default_top_k=default_top_k,
        default_threshold=default_threshold,
        default_boost_factor=default_boost_factor,
        default_combine_order=default_combine_order,
    )