-
Notifications
You must be signed in to change notification settings - Fork 0
/
search.py
185 lines (166 loc) · 7.98 KB
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# -*- coding: UTF-8 -*-
"""
@Project: PPRF
@File: search.py
@Author: Rosenberg
@Date: 2022/12/10 9:17
@Documentation:
...
"""
import os
import os.path
from multiprocessing import cpu_count
from typing import List, Mapping, Tuple, Union
from jsonargparse import CLI
from pyserini.query_iterator import get_query_iterator, TopicsFormat
from tqdm import tqdm
from source import BatchSearchResult, DEFAULT_CACHE_DIR
from source.eval import evaluate
from source.utils import QUERY_NAME_MAPPING
from source.utils.output import OutputWriter
from source.utils.pseudo import PseudoQuerySearcher
def search(
        topic_name: str = 'msmarco-passage-dev-subset',
        pseudo_name: str = 'msmarco_v1_passage_doc2query-t5_expansions_-1',
        pseudo_index_dir: str = None,
        num_pseudo_queries: int = 8,
        num_pseudo_return_hits: int = 1000,
        pseudo_encoder_name: Union[str, List[str]] = "lucene",
        pseudo_doc_index: Union[str, List[str]] = 'msmarco-v1-passage-full',
        pseudo_prf_depth: int = 0,
        pseudo_prf_method: str = 'avg',
        pseudo_rocchio_alpha: float = 0.9,
        pseudo_rocchio_beta: float = 0.1,
        pseudo_rocchio_gamma: float = 0.1,
        pseudo_rocchio_topk: int = 3,
        pseudo_rocchio_bottomk: int = 0,
        query_index: str = None,
        query_k1: float = None,
        query_b: float = None,
        query_rm3: bool = False,
        query_rocchio: bool = False,
        query_rocchio_use_negative: bool = False,
        num_return_hits: int = 1000,
        max_passage: bool = False,
        max_passage_hits: int = 1000,
        max_passage_delimiter: str = '#',
        output_path: str = os.path.join(DEFAULT_CACHE_DIR, "runs"),
        reference_name: str = None,
        metrics: List[str] = None,
        print_result: bool = True,
        use_cache: bool = True,
        threads: int = cpu_count(),
        batch_size: int = cpu_count(),
        device: str = "cpu",
) -> Tuple[BatchSearchResult, BatchSearchResult, Mapping[str, float]]:
    """
    Search one topic set through pseudo queries, write the run file and evaluate it.

    The queries of ``topic_name`` are sent to a ``PseudoQuerySearcher`` in batches of
    ``batch_size``. The collected hits are written with an ``OutputWriter`` to
    ``<output_path>/run.<pseudo_name>.<topic_name>.<num_pseudo_queries>.<encoder>.txt``
    (with a matching ``.log`` file under ``<output_path>/log``), and the run file is
    then passed to ``evaluate``.

    Exactly one of ``pseudo_name`` and ``pseudo_index_dir`` must be given. Because
    ``pseudo_name`` has a non-``None`` default, pass ``pseudo_name=None`` explicitly
    when using ``pseudo_index_dir``.

    :param topic_name: Name of topics; must be a key of ``QUERY_NAME_MAPPING``.
    :param pseudo_name: index name of the candidate pseudo queries; resolved to
        ``<DEFAULT_CACHE_DIR>/indexes/<pseudo_name>``.
    :param pseudo_index_dir: index path to the candidate pseudo queries.
    :param num_pseudo_queries: how many pseudo queries are used for the second stage.
    :param num_pseudo_return_hits: number of hits to return by each pseudo query.
    :param pseudo_encoder_name: path to a query encoder pytorch checkpoint or hgf encoder
        model name; a list of names is labelled ``hybrid`` in the run file name.
    :param pseudo_doc_index: the index of the candidate documents.
    :param pseudo_prf_depth: how many passages are used for PRF, 0: simple retrieval with
        no PRF, > 0: perform PRF.
    :param pseudo_prf_method: PRF method, avg or rocchio.
    :param pseudo_rocchio_alpha: controls the contribution from the query vector.
    :param pseudo_rocchio_beta: controls the contribution from the average vector of the
        positive PRF passages.
    :param pseudo_rocchio_gamma: controls the contribution from the average vector of the
        negative PRF passages.
    :param pseudo_rocchio_topk: set topk passages as positive PRF passages for rocchio.
    :param pseudo_rocchio_bottomk: set bottomk passages as negative PRF passages for
        rocchio, 0: do not use negative PRF passages.
    :param query_index: the index the original query uses to perform sparse retrieval.
    :param query_k1: bm25 k1 for the original query search.
    :param query_b: bm25 b for the original query search.
    :param query_rm3: whether the rm3 algorithm is used for the first stage search.
    :param query_rocchio: whether the rocchio algorithm is used for the first stage search.
    :param query_rocchio_use_negative: whether rocchio with negatives is used for the
        first stage search.
    :param num_return_hits: how many hits will be returned.
    :param max_passage: select only the max passage from each document.
    :param max_passage_hits: final number of hits when selecting only the max passage.
    :param max_passage_delimiter: delimiter between docid and passage id.
    :param output_path: the directory where the run file will be written.
    :param reference_name: reference name left for the evaluation of p-value.
    :param metrics: metrics to evaluate on.
    :param print_result: whether to print the evaluation result.
    :param use_cache: whether cached scores are used.
    :param threads: maximum threads to use during search.
    :param batch_size: number of queries sent to the searcher per batch; must be >= 1.
    :param device: the device the whole search procedure runs on.
    :return: ``(query_hits, pseudo_hits, scores)`` where the first two are dictionaries
        accumulated from every ``batch_search`` call and ``scores`` is whatever
        ``evaluate`` returns for the written run file.
    :raises ValueError: if both or neither of ``pseudo_name`` / ``pseudo_index_dir`` are
        given, if ``topic_name`` is unknown, if ``batch_size`` is smaller than 1, or if
        ``pseudo_encoder_name`` is neither a string nor a list.
    """
    # Cheap argument validation first, so a bad call fails before any index is loaded.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")
    if topic_name not in QUERY_NAME_MAPPING:
        raise ValueError(f"{topic_name} is current not supported.")

    if pseudo_name is not None:
        if pseudo_index_dir is not None:
            raise ValueError("Can not specify both pseudo_name and pseudo_index")
        pseudo_index_dir = os.path.join(DEFAULT_CACHE_DIR, 'indexes', pseudo_name)
    elif pseudo_index_dir is None:
        raise ValueError("At least specify pseudo_name or pseudo_index")

    searcher = PseudoQuerySearcher(
        pseudo_index_dir, pseudo_doc_index,
        pseudo_encoder_name=pseudo_encoder_name,
        pseudo_prf_depth=pseudo_prf_depth,
        pseudo_prf_method=pseudo_prf_method,
        pseudo_rocchio_alpha=pseudo_rocchio_alpha,
        pseudo_rocchio_beta=pseudo_rocchio_beta,
        pseudo_rocchio_gamma=pseudo_rocchio_gamma,
        pseudo_rocchio_topk=pseudo_rocchio_topk,
        pseudo_rocchio_bottomk=pseudo_rocchio_bottomk,
        query_index=query_index,
        query_k1=query_k1,
        query_b=query_b,
        query_rm3=query_rm3,
        query_rocchio=query_rocchio,
        query_rocchio_use_negative=query_rocchio_use_negative,
        device=device
    )

    query_iterator = get_query_iterator(QUERY_NAME_MAPPING[topic_name], TopicsFormat.DEFAULT)

    # Short encoder label used in the run/log file names.
    if isinstance(pseudo_encoder_name, str):
        pseudo_encoder_full_name = pseudo_encoder_name.split('/')[-1]
    elif isinstance(pseudo_encoder_name, list):
        pseudo_encoder_full_name = "hybrid"
    else:
        raise ValueError("Unexpected type of pseudo_encoder_name.")
    # The default depth 0 is not None, so the PRF suffix is appended in that case as well.
    if pseudo_prf_depth is not None:
        pseudo_encoder_full_name += f"-{pseudo_prf_method}-{pseudo_prf_depth}"

    run_name = f"run.{pseudo_name}.{topic_name}.{num_pseudo_queries}.{pseudo_encoder_full_name}.txt"
    run_path = os.path.join(output_path, run_name)
    log_name = f"run.{pseudo_name}.{topic_name}.{num_pseudo_queries}.{pseudo_encoder_full_name}.log"
    log_path = os.path.join(output_path, "log", log_name)

    output_writer = OutputWriter(
        run_path,
        log_path=log_path,
        max_hits=num_return_hits,
        # Run tag = run file name without its ".txt" extension. It was previously sliced
        # from `output_path`, i.e. the directory, which chopped four characters off the
        # directory path instead of removing the extension.
        tag=os.path.splitext(run_name)[0],
        topics=query_iterator.topics,
        use_max_passage=max_passage,
        max_passage_delimiter=max_passage_delimiter,
        max_passage_hits=max_passage_hits
    )

    batch_queries, batch_qids = list(), list()
    query_hits, pseudo_hits, queries_ids = dict(), dict(), dict()

    def flush_batch() -> None:
        """Search the pending queries, merge their hits and empty the batch."""
        if not batch_queries:
            return
        batch_query_hits, batch_pseudo_hits = searcher.batch_search(
            batch_queries, batch_qids,
            num_pseudo_queries=num_pseudo_queries,
            num_pseudo_return_hits=num_pseudo_return_hits,
            use_cache=use_cache,
            threads=threads,
        )
        query_hits.update(batch_query_hits)
        pseudo_hits.update(batch_pseudo_hits)
        queries_ids.update(dict(zip(batch_qids, batch_queries)))
        batch_queries.clear()
        batch_qids.clear()

    for query_id, text in tqdm(query_iterator):
        batch_queries.append(text)
        batch_qids.append(str(query_id))
        if len(batch_queries) >= batch_size:
            flush_batch()
    # Flush the trailing partial batch; this does not rely on the iterator yielding
    # exactly len(query_iterator.topics) items.
    flush_batch()

    with output_writer:
        output_writer.write(query_hits, pseudo_hits, queries_ids)

    # Keep the evaluation result in its own name rather than rebinding the `metrics` argument.
    scores = evaluate(
        topic_name=topic_name,
        path_to_candidate=run_path,
        reference_name=reference_name,
        metrics=metrics,
        print_result=print_result,
    )
    return query_hits, pseudo_hits, scores
if __name__ == "__main__":
    # Script entry point: hand `search` to jsonargparse's CLI helper.
    CLI(search)