<a href="https://colab.research.google.com/github/agasti-mhatre/CS5610-FinalProject-Frontend/blob/main/reranking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Reranking Retrieval Results

In this notebook, you will continue using the [Pyserini](http://pyserini.io/) library's indexing and retrieval models.  This time, however, you will get an initial set of retrieval results and then write your own reranking code to try to move relevant documents higher in the list.

As before, we start by installing the Python interface. Since it calls the underlying Lucene search engine, which is written in Java, we need to point it to an appropriate Java installation. If, like Colab, your environment does not have Java 21, run the following cell (or whatever makes sense for your platform).

In [2]:
## Install Java 21 on Colab (skip this cell if Java 21 is already available)
!apt-get install openjdk-21-jre-headless -qq > /dev/null
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-21-openjdk-amd64"
!update-alternatives --set java /usr/lib/jvm/java-21-openjdk-amd64/bin/java
!java -version

openjdk version "21.0.6" 2025-01-21
OpenJDK Runtime Environment (build 21.0.6+7-Ubuntu-122.04.1)
OpenJDK 64-Bit Server VM (build 21.0.6+7-Ubuntu-122.04.1, mixed mode, sharing)


In [3]:
!pip install pyserini
# Install faiss-gpu instead if you have a GPU.
# It's a pyserini dependency, but we won't need it until the next assignment.
!pip install faiss-cpu

Collecting pyserini
  Downloading pyserini-0.44.0.tar.gz (195.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 195.3/195.3 MB 6.0 MB/s eta 0:00:00
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting pyjnius>=1.6.0 (from pyserini)
  Downloading pyjnius-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting onnxruntime>=1.8.1 (from pyserini)
  Downloading onnxruntime-1.21.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting tiktoken>=0.4.0 (from pyserini)
  Downloading tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting coloredlogs (from onnxruntime>=1.8.1->pyserini)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from t

We initialize the searcher with a prebuilt index for the Robust04 collection, which Pyserini downloads automatically the first time it is needed. Note that the index takes up about 1.6 GB of disk space.

In [4]:
from pyserini.search.lucene import LuceneSearcher

searcher = LuceneSearcher.from_prebuilt_index('robust04')

Downloading index at https://rgw.cs.uwaterloo.ca/pyserini/indexes/lucene/lucene-inverted.disk45.20240803.36f7e3.tar.gz...


lucene-inverted.disk45.20240803.36f7e3.tar.gz: 1.66GB [00:38, 46.8MB/s]                            


Now we can search for a query and inspect the results:

In [5]:
hits = searcher.search('black bear attacks', 1000)

# Prints the first 10 hits
for i in range(0, 10):
    print(f'{i+1:2} {hits[i].docid:15} {hits[i].score:.5f}')

 1 LA092790-0015   7.06680
 2 LA081689-0039   6.89020
 3 FBIS4-16530     6.61630
 4 LA102589-0076   6.46450
 5 FT932-15491     6.25090
 6 FBIS3-12276     6.24630
 7 LA091090-0085   6.17030
 8 FT922-13519     6.04270
 9 LA052790-0205   5.94060
10 LA103089-0041   5.90650


The `LuceneIndexReader` class provides various methods for reading the index directly. For example, we can fetch a raw document from the index given its `docid`:

In [6]:
from pyserini.index import LuceneIndexReader
from IPython.display import display, HTML

reader = LuceneIndexReader.from_prebuilt_index('robust04')

doc = reader.doc('LA092790-0015').raw()
display(HTML('<div style="font-family: Times New Roman; padding-bottom:10px">' + doc + '</div>'))

This is the top-ranked document for the query `black bear attacks` above. Given the raw text, we can obtain its analyzed form (i.e., tokenized, stemmed, with stopwords removed, etc.). Here we show the first ten tokens:

In [7]:
analyzed = reader.analyze(doc)
analyzed[0:10]

['date',
 'p',
 'septemb',
 '27',
 '1990',
 'thursdai',
 'ventura',
 'counti',
 'edit',
 'p']

## Retrieving Initial Ranked Lists

Pyserini can also load standard evaluation sets such as Robust04, which contains 250 queries, or "topics" as they are called at TREC.

In [8]:
from pyserini.search import get_topics
topics = get_topics('robust04')
print(f'{len(topics)} queries total')

250 queries total


The topics are in a dictionary, whose keys are integers uniquely identifying each query. Each topic contains the following fields:

* `title`: TREC's term for the brief query a user might actually type;
* `description`: a longer form of the query in the form of a complete sentence; and
* `narrative`: a description of what the user is looking for and what kinds of results would be relevant or non-relevant.

In [None]:
topics[301]

{'narrative': 'A relevant document must as a minimum identify the organization and the type of illegal activity (e.g., Columbian cartel exporting cocaine). Vague references to international drug trade without identification of the organization(s) involved would not be relevant.',
 'description': 'Identify organizations that participate in international criminal activity, the activity, and, if possible, collaborating organizations and the countries involved.',
 'title': 'International Organized Crime'}

For your experiments, we'll divide the topics into a development set (the first 125 topics in dictionary order) and a test set (the remaining 125).

In [9]:
dev_topics = {k:topics[k] for k in list(topics.keys())[:125]}
test_topics = {k:topics[k] for k in list(topics.keys())[125:]}

Now, we'll fetch the relevance judgments for the Robust04 queries, which TREC calls "qrels".

In [10]:
from urllib.request import urlopen

qfile = 'https://github.com/castorini/anserini-tools/blob/63ceeab1dd94c1221f29b931d868e8fab67cc25c/topics-and-qrels/qrels.robust04.txt?raw=true'
qrels = []
for line in urlopen(qfile):
  # Each line: query id, feedback round (always 0), docid, relevance label
  qid, _round, docid, score = line.decode('utf-8').strip().split()
  qrels.append([int(qid), 0, docid, int(score)])

Each record in the qrel contains four fields:

1. the numeric identifier of the query;
2. the round of relevance feedback, which is here always 0;
3. the identifier of a document that has been judged; and
4. the relevance score of that document.

In Robust04, all relevance judgments are binary, i.e., 1 or 0. Note that not all non-relevant documents are recorded. The qrel file only contains those documents the annotators actually looked at; the vast majority of documents in the collection have not been judged. In IR evaluation, we assume that unannotated documents are non-relevant.

In [11]:
qrels[0:10]

[[301, 0, 'FBIS3-10082', 1],
 [301, 0, 'FBIS3-10169', 0],
 [301, 0, 'FBIS3-10243', 1],
 [301, 0, 'FBIS3-10319', 0],
 [301, 0, 'FBIS3-10397', 1],
 [301, 0, 'FBIS3-10491', 1],
 [301, 0, 'FBIS3-10555', 0],
 [301, 0, 'FBIS3-10622', 1],
 [301, 0, 'FBIS3-10634', 0],
 [301, 0, 'FBIS3-10635', 0]]

We collect the top 1000 hits for both the dev and test sets.

In [12]:
# Compute the top-k BM25 list (default k=1000) for each topic, using its title as the query
def topic_hits(searcher, topics, k=1000):
  hits = {}
  for topic, info in topics.items():
    print(topic, info['title'])
    hits[topic] = [(hit.docid, hit.score) for hit in searcher.search(info['title'], k)]
  return hits

dev_hits = topic_hits(searcher, dev_topics)
test_hits = topic_hits(searcher, test_topics)

350 Health and Computer Terminals
351 Falkland petroleum exploration
352 British Chunnel impact
353 Antarctica exploration
354 journalist risks
355 ocean remote sensing
356 postmenopausal estrogen Britain
357 territorial waters dispute
358 blood-alcohol fatalities
359 mutual fund predictors
360 drug legalization benefits
361 clothing sweatshops
362 human smuggling
363 transportation tunnel disasters
364 rabies
365 El Nino
366 commercial cyanide uses
367 piracy
368 in vitro fertilization
369 anorexia nervosa bulimia
370 food/drug laws
371 health insurance holistic
372 Native American casino
373 encryption equipment export
374 Nobel prize winners
375 hydrogen energy
376 World Court
377 cigar smoking
378 euro opposition
379 mainstreaming
380 obesity medical treatment
381 alternative medicine
382 hydrogen fuel automobiles
383 mental illness drugs
384 space station moon
385 hybrid fuel cars
386 teaching disabled children
387 radioactive waste
388 organic soil enhancement
389 illegal technol

## Evaluating Initial Ranked Lists



When reranking, an important metric is the _recall_ of the initial set of results. It gives an upper bound, or "headroom", on the improvement reranking can achieve: a relevant document that is not in the initial list can never be moved up. If the recall of the initial ranked lists is too low, we know we need to improve the initial retrieval model instead.
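Concretely, writing $\mathrm{Rel}(q)$ for the set of documents judged relevant to query $q$ and $\mathrm{Top}_k(q)$ for its initial top-$k$ list (notation introduced here for illustration), recall at $k$ is the fraction of relevant documents that made it into the list. Reordering $\mathrm{Top}_k(q)$ leaves the numerator unchanged.

$$
\mathrm{Recall@}k(q) = \frac{\lvert \mathrm{Rel}(q) \cap \mathrm{Top}_k(q) \rvert}{\lvert \mathrm{Rel}(q) \rvert}
$$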

For this assignment, you will work with fixed initial ranked lists from Pyserini's BM25 model, but it's still useful to see how much room for improvement reranking has.

As before, you should process the `qrels` data to find the relevant results for each query.

In [28]:
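from collections import defaultdict

# qrels_map[qid][docid] -> relevance label (0 for unjudged documents)
qrels_map = defaultdict(lambda: defaultdict(int))
# rel_count_map[qid] -> number of documents judged relevant for the query
rel_count_map = defaultdict(int)
for qid, _, docid, is_rel in qrels:
  qrels_map[qid][docid] = is_rel
  # Count any positive label as relevant
  rel_count_map[qid] += int(is_rel > 0)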

In [29]:
print(dev_hits[350])
print(qrels_map[350])
print(rel_count_map[350])

[('LA052290-0188', 8.458999633789062), ('LA060690-0112', 7.700099945068359), ('FT922-6787', 7.6209001541137695), ('FT931-7146', 7.403800010681152), ('FT923-4291', 7.059599876403809), ('FR940805-2-00111', 6.98390007019043), ('FR940919-2-00012', 6.932000160217285), ('LA100690-0107', 6.867800235748291), ('FR940822-0-00016', 6.8653998374938965), ('FR940610-0-00106', 6.82450008392334), ('FT931-7096', 6.816500186920166), ('LA062390-0104', 6.751399993896484), ('FR940805-2-00036', 6.702400207519531), ('FR941107-2-00067', 6.663400173187256), ('FR940610-1-00047', 6.576900005340576), ('FT943-12885', 6.5345001220703125), ('FT944-11341', 6.5015997886657715), ('FR940826-0-00055', 6.480000019073486), ('LA010289-0072', 6.478499889373779), ('FR940913-0-00079', 6.470300197601318), ('FBIS3-22305', 6.45959997177124), ('FT944-12140', 6.456099987030029), ('FR940706-0-00094', 6.455100059509277), ('FR940825-0-00110', 6.455099105834961), ('FBIS3-59668', 6.378499984741211), ('FR940610-0-00127', 6.37540006637573

In [31]:
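## TODO [15 points]: Compute Recall@1000 for the dev_hits and test_hits data
## and print it out.

def comp_recall_1000(hits):
  """Return a list of (qid, recall) pairs, one per query in `hits`."""
  res = []
  for qid, ranking in hits.items():
    judged = qrels_map.get(qid, {})
    # Count retrieved documents with a positive relevance label;
    # .get() avoids inserting every unjudged docid into the defaultdict.
    true_pos = sum(1 for docid, _ in ranking if judged.get(docid, 0) > 0)
    num_rel = rel_count_map.get(qid, 0)
    # Queries with no relevant documents get recall 0 by convention here.
    res.append((qid, true_pos / num_rel if num_rel > 0 else 0.0))
  return res

print("Recall@1000 for dev_hits: ", comp_recall_1000(dev_hits))
print("Recall@1000 for test_hits: ", comp_recall_1000(test_hits))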

Recall@1000 for dev_hits:  [(350, 0.3235294117647059), (351, 0.6458333333333334), (352, 0.11788617886178862), (353, 0.5655737704918032), (354, 0.32409972299168976), (355, 0.5333333333333333), (356, 0.17647058823529413), (357, 0.562962962962963), (358, 0.9411764705882353), (359, 0.35714285714285715), (360, 0.7748344370860927), (361, 0.7777777777777778), (362, 0.6410256410256411), (363, 0.625), (364, 0.9428571428571428), (365, 1.0), (366, 0.9292929292929293), (367, 0.37566137566137564), (368, 0.6229508196721312), (369, 0.6923076923076923), (370, 0.4880952380952381), (371, 0.29411764705882354), (372, 0.4489795918367347), (373, 0.7878787878787878), (374, 0.7843137254901961), (375, 0.575), (376, 0.22549019607843138), (377, 0.9487179487179487), (378, 0.20408163265306123), (379, 0.3125), (380, 0.8571428571428571), (381, 0.32142857142857145), (382, 0.9090909090909091), (383, 0.2054794520547945), (384, 1.0), (385, 0.6976744186046512), (386, 0.5789473684210527), (387, 0.8235294117647058), (388, 
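The per-query lists above are hard to compare at a glance. A minimal sketch for reducing them to one macro-averaged number, assuming the `(qid, value)` pairs returned by `comp_recall_1000` (the helper name `mean_metric` is hypothetical, not part of Pyserini):

```python
from statistics import mean
from typing import Iterable, Tuple


def mean_metric(per_query: Iterable[Tuple[int, float]]) -> float:
    """Macro-average a list of (qid, value) pairs into one score."""
    values = [value for _, value in per_query]
    # Guard against an empty topic set instead of raising StatisticsError
    return mean(values) if values else 0.0
```

Usage would be along the lines of `mean_metric(comp_recall_1000(dev_hits))`. Queries with no relevant documents contribute 0 under `comp_recall_1000`, which pulls the average down; whether to exclude them is a choice worth making explicitly.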

For a given set of top-1000 lists, Recall@1000 will not change after reranking, since reranking only reorders the same 1000 documents. What will change are rank-sensitive metrics such as MAP and NDCG. You should compute MAP@1000 for the initial `dev_hits` and `test_hits` data.
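For reference, a minimal sketch of AP@k and MAP@k under the standard definition, written as standalone functions over plain docid lists and sets (the names `average_precision` and `mean_average_precision` are illustrative; your Homework 3 code may differ in details):

```python
from typing import Dict, List, Set


def average_precision(ranked_docids: List[str], relevant: Set[str], k: int = 1000) -> float:
    """AP@k: sum of precision@i at each rank i <= k holding a relevant document,
    divided by the total number of relevant documents for the query."""
    if not relevant:
        return 0.0  # convention here: queries without relevant docs score 0
    num_rel_seen = 0
    precision_sum = 0.0
    for i, docid in enumerate(ranked_docids[:k], start=1):
        if docid in relevant:
            num_rel_seen += 1
            precision_sum += num_rel_seen / i  # precision at rank i
    return precision_sum / len(relevant)


def mean_average_precision(runs: Dict[int, List[str]],
                           relevant_by_qid: Dict[int, Set[str]],
                           k: int = 1000) -> float:
    """MAP@k: mean of AP@k over every query in `runs`."""
    aps = [average_precision(docs, relevant_by_qid.get(qid, set()), k)
           for qid, docs in runs.items()]
    return sum(aps) / len(aps) if aps else 0.0


# Toy check: relevant docs at ranks 1 and 3, plus one relevant doc never retrieved
print(average_precision(["d1", "d2", "d3", "d4"], {"d1", "d3", "d9"}))  # (1/1 + 2/3) / 3
```

With this notebook's data, `runs` could be `{qid: [docid for docid, _ in ranking] for qid, ranking in dev_hits.items()}` and `relevant_by_qid` could be `{qid: {d for d, rel in docs.items() if rel > 0} for qid, docs in qrels_map.items()}`. As in trec_eval's MAP, relevant documents that are never retrieved still count in the denominator.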

In [None]:
## TODO [10 points]: Adapt your code from Homework 3 to compute MAP@1000 for
## the dev_hits and test_hits data and print it out.



## Reranking Search Results

In this final part of the assignment, you should implement a ranking function that, hopefully, improves on the baseline BM25 ranking. You may use the BM25 score for each document as input, as well as the query, of course, and any other properties of the documents you look up with the `reader` object.  After computing a new score for each candidate, re-sort the top-1000 results by your model's score.
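One way to structure this, sketched under the assumption that candidates are stored as in `dev_hits` (a dict from query id to a list of `(docid, bm25_score)` pairs). The `rerank` helper and the `score_fn(qid, docid, bm25_score)` signature are hypothetical choices, not a required interface:

```python
from typing import Callable, Dict, List, Tuple

Ranking = List[Tuple[str, float]]  # (docid, score) pairs, best first


def rerank(hits: Dict[int, Ranking],
           score_fn: Callable[[int, str, float], float]) -> Dict[int, Ranking]:
    """Rescore every candidate with score_fn(qid, docid, bm25_score) and
    re-sort each list by the new score, highest first."""
    reranked = {}
    for qid, ranking in hits.items():
        rescored = [(docid, score_fn(qid, docid, bm25)) for docid, bm25 in ranking]
        # Python's sort is stable even with reverse=True, so tied documents
        # keep their original BM25 order.
        reranked[qid] = sorted(rescored, key=lambda pair: pair[1], reverse=True)
    return reranked


# Toy usage: a score function that simply inverts the BM25 order
toy_hits = {1: [("a", 3.0), ("b", 2.0), ("c", 1.0)]}
print(rerank(toy_hits, lambda qid, docid, bm25: -bm25))
# {1: [('c', -1.0), ('b', -2.0), ('a', -3.0)]}
```

A real `score_fn` could close over `topics`, `reader`, or precomputed features. Keeping it separate from the sorting logic makes it easy to evaluate several scoring ideas with the same MAP code.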

You may use anything you've learned in this course (or in another course) to build your ranking function. For example, you could implement pseudo-relevance feedback or a relevance model, which would treat the top of each ranked list (e.g., the top 100) as if it were truly relevant and re-estimate model parameters from it. You could tune different BM25, query likelihood, or sequential dependence model parameters. You could try to learn different weights or embeddings for different fields in documents. You could use implementations of transformer language models such as BERT or SentenceBERT to score the compatibility of queries and documents. To be clear, you don't have to use any of these approaches; you are free to try whatever ideas you like.
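As an illustration of the pseudo-relevance feedback idea, here is a minimal sketch of a simplified relevance-model (RM1-style) expansion. It assumes each feedback document is given as its analyzed tokens (for example from `reader.analyze(reader.doc(docid).raw())`, as shown earlier) together with its BM25 score, which stands in for the query likelihood a true relevance model would use. The helper names `expansion_terms` and `feedback_score` are hypothetical:

```python
from collections import Counter
from typing import Dict, List, Tuple


def expansion_terms(feedback_docs: List[Tuple[List[str], float]],
                    n_terms: int = 10) -> Dict[str, float]:
    """Simplified RM1-style estimate: P(w|R) is proportional to
    the sum over feedback docs d of P(w|d) * weight(d).

    feedback_docs holds (analyzed tokens, retrieval score) for the top-ranked
    documents, which are treated as if they were relevant.
    """
    total_score = sum(score for _, score in feedback_docs)
    term_weights: Counter = Counter()
    for tokens, score in feedback_docs:
        if not tokens:
            continue
        # Document weight: its share of the total retrieval score
        # (uniform if the scores do not sum to a positive number).
        doc_weight = score / total_score if total_score > 0 else 1.0 / len(feedback_docs)
        for term, tf in Counter(tokens).items():
            term_weights[term] += doc_weight * tf / len(tokens)  # MLE P(w|d)
    top = term_weights.most_common(n_terms)
    norm = sum(weight for _, weight in top)
    if norm <= 0:
        return {}
    # Renormalize the kept terms so their weights sum to 1
    return {term: weight / norm for term, weight in top}


def feedback_score(tokens: List[str], expansion: Dict[str, float]) -> float:
    """Score a candidate by the expansion-weighted relative frequency of its terms."""
    if not tokens:
        return 0.0
    counts = Counter(tokens)
    return sum(weight * counts[term] / len(tokens) for term, weight in expansion.items())


# Toy usage with made-up tokens and scores
fb = [(["bear", "attack", "bear"], 2.0), (["bear", "park"], 2.0)]
model = expansion_terms(fb)
print(model)
print(feedback_score(["bear", "hike"], model))
```

A candidate's final score could then combine its BM25 score with `feedback_score`; the number of feedback documents, `n_terms`, and the combination weight are all parameters to tune on `dev_hits`.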

If your reranking model has tunable parameters, you should tune them on the `dev_hits` set. In the end, you will also evaluate MAP@1000 on the `test_hits` set.
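A common pattern for a single tunable weight is to min-max normalize BM25 and your new feature within each query's candidate list, interpolate them, and pick the weight by grid search on the dev topics. The sketch below assumes a per-document feature dict and a `dev_metric` callable (for example, one returning MAP@1000 on `dev_hits` for a given weight); all names here are hypothetical:

```python
from typing import Callable, Dict, List, Sequence, Tuple


def minmax_normalize(scores: Sequence[float]) -> List[float]:
    """Rescale scores to [0, 1]; constant lists map to all zeros."""
    if not scores:
        return []
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [0.0] * len(scores)
    return [(s - lo) / (hi - lo) for s in scores]


def interpolate(ranking: List[Tuple[str, float]], feature: Dict[str, float],
                alpha: float) -> List[Tuple[str, float]]:
    """Score = alpha * norm(BM25) + (1 - alpha) * norm(feature), best first."""
    docids = [docid for docid, _ in ranking]
    bm25 = minmax_normalize([score for _, score in ranking])
    other = minmax_normalize([feature.get(docid, 0.0) for docid in docids])
    combined = [(d, alpha * b + (1 - alpha) * o) for d, b, o in zip(docids, bm25, other)]
    return sorted(combined, key=lambda pair: pair[1], reverse=True)


def tune_alpha(alphas: Sequence[float],
               dev_metric: Callable[[float], float]) -> Tuple[float, float]:
    """Grid search on the dev set: return (best alpha, its dev score).
    max() keeps the first alpha among ties."""
    return max(((alpha, dev_metric(alpha)) for alpha in alphas), key=lambda p: p[1])


ranking = [("a", 3.0), ("b", 2.0), ("c", 1.0)]
print([d for d, _ in interpolate(ranking, {"c": 10.0}, alpha=0.0)])  # ['c', 'a', 'b']
print(tune_alpha([0.0, 0.25, 0.5, 0.75, 1.0], lambda a: 1 - abs(a - 0.5)))  # (0.5, 1.0)
```

Min-max normalization is one choice among several (z-scores are another); some normalization matters because raw BM25 scores and a new feature usually live on very different scales. Choosing the weight on `dev_hits` only, then reporting its MAP@1000 on `test_hits` once, keeps the test estimate honest.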

**TODO**: Put any explanation of your reranking function here.

In [None]:
## TODO [70 points]: Implement a reranking function that takes a query, the
## reader, and an initial ranking and computes new scores.
## Like BM25, higher should be better.
## If you train parameters or set hyperparameters for this ranking function,
## do that here, as well.

In [None]:
## TODO [5 points]: Compute and print out the MAP@1000 score after reranking
## on dev_hits and test_hits.