In [1]:
# Install dependencies (uncomment to run):
# %pip install python-terrier fast-forward-indexes
# Note: torch==1.13.1 is too old for this setup; do not pin that version.

In [2]:
import logging

# Surface INFO-level log records, e.g. the fast_forward timing and
# early-stopping progress messages in the outputs further down.
logging.basicConfig(level="INFO")

In [4]:
from pathlib import Path
from fast_forward import OnDiskIndex, Mode, Ranking
from fast_forward.encoder import TCTColBERTQueryEncoder

# On-disk Fast-Forward index; the filename indicates TCT-ColBERT vectors for
# the MS MARCO v1 passage corpus.
INDEX_FILE = Path("../ff_msmarco-v1-passage.tct_colbert.h5")

# Query encoder; presumably it must match the model used to build the index
# (both are TCT-ColBERT) — confirm if the index file is swapped.
encoder = TCTColBERTQueryEncoder("castorini/tct_colbert-msmarco")
ff_index = OnDiskIndex.load(INDEX_FILE, encoder, mode=Mode.MAXP)

100%|██████████| 8841823/8841823 [00:15<00:00, 581160.58it/s]


In [5]:
import ir_datasets

# TREC DL 2019 passage queries, restricted to those that have relevance judgments.
dataset = ir_datasets.load("msmarco-passage/trec-dl-2019/judged")

# Query ID -> query text. The ranking carries the texts, presumably so the
# index's query encoder can embed them at re-ranking time.
queries = dict((query.query_id, query.text) for query in dataset.queries_iter())
r = Ranking.from_file(Path("msmarco-passage-test2019-sparse10000.txt"), queries)

[INFO] Please confirm you agree to the MSMARCO data usage agreement found at <http://www.msmarco.org/dataset.aspx>
[INFO] [starting] https://trec.nist.gov/data/deep/2019qrels-pass.txt
[INFO] [finished] https://trec.nist.gov/data/deep/2019qrels-pass.txt: [00:00] [187kB] [460kB/s]
[INFO] [starting] https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz
[INFO] [finished] https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz: [00:00] [4.28kB] [24.5MB/s]
  df = pd.read_csv(


In [6]:
# Standard re-ranking: every one of the top-5000 candidates per query is
# scored (about 144 s in the recorded run).
top_candidates = r.cut(5000)
ff_out = ff_index(top_candidates)

INFO:fast_forward.index:computed scores in 144.32676931499964 seconds


In [7]:
# Re-ranking with early stopping.

# Candidate depth, same as the standard re-ranking run. The last
# early-stopping interval is tied to it so the final check happens at the
# full cut-off depth instead of a stale hard-coded number.
ES_DEPTH = 5000

# Interpolation weight early stopping assumes. It must equal the alpha later
# passed to `Ranking.interpolate` for this run (0.2); otherwise the
# early-stopping decisions were made for a different score combination.
ES_ALPHA = 0.2

ff_out_es = ff_index(
    r.cut(ES_DEPTH),
    early_stopping=10,  # number of top results early stopping targets
    early_stopping_alpha=ES_ALPHA,
    early_stopping_intervals=(800, ES_DEPTH),
)

INFO:fast_forward.index:depth 800: 16 queries left
INFO:fast_forward.index:depth 5000: 9 queries left
INFO:fast_forward.index:computed scores in 157.70563292299994 seconds


In [8]:
from ir_measures import calc_aggregate, AP, RR
from fast_forward.util import to_ir_measures

# Weight of the first-stage ranking `r` when interpolating it with the
# Fast-Forward scores. Keep this equal to the `early_stopping_alpha` used for
# the early-stopping run (0.2), since early stopping is computed for that alpha.
INTERPOLATION_ALPHA = 0.2

# Full metric set for runs in which every candidate is scored.
FULL_MEASURES = [AP(rel=2) @ 1000, RR(rel=2) @ 10]
# The early-stopping run was configured with early_stopping=10, so only a
# cut-off-10 metric is reported for it; deeper ranks are presumably not
# comparable to standard re-ranking.
TOP10_MEASURES = [RR(rel=2) @ 10]

runs = [
    ("no re-ranking", r, FULL_MEASURES),
    (
        "standard re-ranking",
        r.interpolate(ff_out, INTERPOLATION_ALPHA),
        FULL_MEASURES,
    ),
    (
        "re-ranking with early stopping",
        r.interpolate(ff_out_es, INTERPOLATION_ALPHA),
        TOP10_MEASURES,
    ),
]

for name, run, measures in runs:
    # A fresh qrels iterator is requested for every evaluation, as the
    # original did for each calc_aggregate call.
    scores = calc_aggregate(measures, dataset.qrels_iter(), to_ir_measures(run))
    print(f"{name}:\n{scores}\n")

no re-ranking:
 {AP(rel=2)@1000: 0.30128706043561426, RR(rel=2)@10: 0.7024178663713547} 

standard re-ranking:
 {AP(rel=2)@1000: 0.45949573660757204, RR(rel=2)@10: 0.901937984496124} 

re-ranking with early stopping:
 {RR(rel=2)@10: 0.901937984496124}
