##Install and imports

### install

In [None]:
%%capture
%%bash
pip install -q -U bitsandbytes
pip install -q -U git+https://github.com/huggingface/transformers.git
pip install -q -U git+https://github.com/huggingface/peft.git
pip install -q -U git+https://github.com/huggingface/accelerate.git

In [None]:
!pip install jsonlines

Collecting jsonlines
  Downloading jsonlines-4.0.0-py3-none-any.whl (8.7 kB)
Installing collected packages: jsonlines
Successfully installed jsonlines-4.0.0


### import

In [None]:
import os
import json
import tqdm
import sys
import gzip
from collections import defaultdict

## Load model

In [None]:
!pip install torch==2.3.0

Collecting torch==2.3.0
  Downloading torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl (779.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m779.1/779.1 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-nccl-cu12==2.20.5 (from torch==2.3.0)
  Downloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m176.2/176.2 MB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
Collecting triton==2.3.0 (from torch==2.3.0)
  Downloading triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (168.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m168.1/168.1 MB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: triton, nvidia-nccl-cu12, torch
  Attempting uninstall: triton
    Found existing installation: triton 2.2.0
    Uninstalling triton-2.2.0:
      Successfully uninstalled triton-2.2.0
  Attempting uninstall: nvidia-nccl-cu12
  

In [None]:
%%capture
from transformers import AutoTokenizer, LlamaForCausalLM, AutoModelForCausalLM

model_name = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(model_name, truncation=True, padding=True, padding_side="left", maximum_length = 2048, model_max_length = 2048)
model = AutoModelForCausalLM.from_pretrained(model_name, load_in_4bit = True, device_map = 'auto')
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = model.generation_config.eos_token_id

## Query expansion

### Load queries, candidate documents and qrels

In [None]:
# Add needed imports and donwload queries and candidate documents files
!pip install sentence-transformers
import requests
import logging
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logging.basicConfig(format='%(asctime)s - %(message)s',datefmt='%Y-%m-%d %H:%M:%S')

!wget https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz
!wget https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-passagetest2019-top1000.tsv.gz

Collecting sentence-transformers
  Downloading sentence_transformers-2.7.0-py3-none-any.whl (171 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/171.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m171.5/171.5 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: sentence-transformers
Successfully installed sentence-transformers-2.7.0
--2024-05-19 16:12:27--  https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz
Resolving msmarco.z22.web.core.windows.net (msmarco.z22.web.core.windows.net)... 20.150.34.1
Connecting to msmarco.z22.web.core.windows.net (msmarco.z22.web.core.windows.net)|20.150.34.1|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4276 (4.2K) [application/x-gzip]
Saving to: ‘msmarco-test2019-queries.tsv.gz’


2024-05-19 16:12:27 (1.25 GB/s) - ‘msmarco-test2019-queries.tsv.gz’ saved [4276/4276]

--2024-05-19 16:1

In [None]:
# Download qrels file
def download_txt_file(url, save_path):
    response = requests.get(url)
    if response.status_code == 200:
        with open(save_path, 'wb') as file:
            file.write(response.content)
        print("File downloaded successfully.")
    else:
        print("Failed to download file.")

# Example usage
url = "https://trec.nist.gov/data/deep/2019qrels-pass.txt"
save_path = "qrels.txt"
download_txt_file(url, save_path)

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): trec.nist.gov:443
DEBUG:urllib3.connectionpool:https://trec.nist.gov:443 "GET /data/deep/2019qrels-pass.txt HTTP/1.1" 200 187092


File downloaded successfully.


In [None]:
# Retrieve the queries
from sentence_transformers import util

queries = {}
queries_filepath = os.path.join('/content/msmarco-test2019-queries.tsv.gz')
if not os.path.exists(queries_filepath):
    logging.info("Download "+os.path.basename(queries_filepath))
    util.http_get('https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz', queries_filepath)

with gzip.open(queries_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        queries[qid] = query

In [None]:
#Read which passages are relevant
relevant_docs = defaultdict(lambda: defaultdict(int))
qrels_filepath = os.path.join("/content/qrels.txt")

if not os.path.exists(qrels_filepath):
    logging.info("Download "+os.path.basename(qrels_filepath))
    util.http_get('https://trec.nist.gov/data/deep/2019qrels-pass.txt', qrels_filepath)


with open(qrels_filepath) as fIn:
    for line in fIn:
        qid, _, pid, score = line.strip().split()
        score = int(score)
        if score > 0:
            relevant_docs[qid][pid] = score

# Only use queries that have at least one relevant passage
relevant_qid = []
for qid in queries:
    if len(relevant_docs[qid]) > 0:
        relevant_qid.append(qid)

In [None]:
# Read the top 1000 passages that are supposed to be re-ranked
passage_filepath = os.path.join("/content/msmarco-passagetest2019-top1000.tsv.gz")

if not os.path.exists(passage_filepath):
    logging.info("Download "+os.path.basename(passage_filepath))
    util.http_get('https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-passagetest2019-top1000.tsv.gz', passage_filepath)



passage_cand = {}
with gzip.open(passage_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        qid, pid, query, passage = line.strip().split("\t")
        if qid not in passage_cand:
            passage_cand[qid] = []

        passage_cand[qid].append([pid, passage])

logging.info("Queries: {}".format(len(queries)))

INFO:root:Queries: 200


### Prtomp queries


In [None]:
# Define the prompt template for query reformulation
prompt_template = "Answer the following query: {{query}} Give the rationale before answering"

In [None]:
def reformulate_query_with_cot(original_query, model, tokenizer, max_length=512):
    # Using the prompt template from the scientific paper
    prompt_template = "Answer the following query: {query} Give the rationale before answering."
    prompt = prompt_template.format(query=original_query)

    # Encode and generate response
    inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True, max_length=max_length).to(model.device)
    outputs = model.generate(**inputs, max_length=max_length, num_beams=5, no_repeat_ngram_size=2, early_stopping=True)

    # Decode the generated tokens to string
    new_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return new_query

In [None]:
!pip install tqdm



In [None]:
from tqdm import tqdm

# Generate reformulated queries using the updated function
reformulated_queries = {}
for qid in tqdm(relevant_qid, desc="Processing queries"):
    query = queries.get(qid, None)
    if query:
        reformulated_queries[qid] = reformulate_query_with_cot(query, model, tokenizer)

DEBUG:tensorflow:Falling back to TensorFlow client; we recommended you install the Cloud TPU client directly with pip install cloud-tpu-client.
DEBUG:h5py._conv:Creating converter from 7 to 5
DEBUG:h5py._conv:Creating converter from 5 to 7
DEBUG:h5py._conv:Creating converter from 7 to 5
DEBUG:h5py._conv:Creating converter from 5 to 7
DEBUG:jax._src.path:etils.epath found. Using etils.epath for file I/O.
INFO:numexpr.utils:NumExpr defaulting to 2 threads.
Processing queries: 100%|██████████| 43/43 [1:45:21<00:00, 147.00s/it]


In [None]:
print(reformulated_queries)



In [None]:
print(queries)

{'1108939': 'what slows down the flow of blood', '1112389': 'what is the county for grand rapids, mn', '792752': 'what is ruclip', '1119729': 'what do you do when you have a nosebleed from having your nose', '1105095': 'where is sugar lake lodge located', '1105103': 'where is steph currys home in nc', '1128373': 'iur definition', '1127622': 'meaning of heat capacity', '1124979': 'synonym for treatment', '885490': 'what party is paul ryan in', '1119827': 'cast of sky captain and the world of tomorrow', '190044': 'foods to detox liver naturally', '500575': "sop's policy", '883785': 'what origin is the last name goins', '264403': 'how long is recovery from a face lift and neck lift', '1108100': 'what type of movement do bacteria exhibit?', '421756': 'is prorate the same as daily rate', '1108307': 'what trail did thousands use to get to the gold rush', '966413': 'where are the benefits of cinnamon as a supplement?', '1111546': 'what is the medium for an artisan', '156493': 'do goldfish gro

### Evaluate cross-encoder
The comparison will be made on TinyBERT cross-encoder


In [None]:
!pip install pytrec_eval

Collecting pytrec_eval
  Downloading pytrec_eval-0.5.tar.gz (15 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pytrec_eval
  Building wheel for pytrec_eval (setup.py) ... [?25l[?25hdone
  Created wheel for pytrec_eval: filename=pytrec_eval-0.5-cp310-cp310-linux_x86_64.whl size=308216 sha256=8b5a40e205c811fe5be2f409e99d43a5978e36393eabd44bb39a519380788b27
  Stored in directory: /root/.cache/pip/wheels/51/3a/cd/dcc1ddfc763987d5cb237165d8ac249aa98a23ab90f67317a8
Successfully built pytrec_eval
Installing collected packages: pytrec_eval
Successfully installed pytrec_eval-0.5


Evaluation without prompted queries

In [None]:
import tqdm
import numpy as np
import pytrec_eval
from sentence_transformers.cross_encoder import CrossEncoder

# model_save_path = os.path.join("/content/MiniLM")
#model_save_path = os.path.join("/content/distilroBERTa")
model_save_path = os.path.join("/content/TinyBERT")


queries_result_list = []
run = {}
model = CrossEncoder(model_save_path, max_length=512)

for qid in tqdm.tqdm(relevant_qid):
    query = queries[qid]

    cand = passage_cand[qid]
    pids = [c[0] for c in cand]
    corpus_sentences = [c[1] for c in cand]

    cross_inp = [[query, sent] for sent in corpus_sentences]

    if model.config.num_labels > 1: #Cross-Encoder that predict more than 1 score, we use the last and apply softmax
        cross_scores = model.predict(cross_inp, apply_softmax=True)[:, 1].tolist()
    else:
        cross_scores = model.predict(cross_inp).tolist()

    cross_scores_sparse = {}
    for idx, pid in enumerate(pids):
        cross_scores_sparse[pid] = cross_scores[idx]

    sparse_scores = cross_scores_sparse
    run[qid] = {}
    for pid in sparse_scores:
        run[qid][pid] = float(sparse_scores[pid])

INFO:sentence_transformers.cross_encoder.CrossEncoder:Use pytorch device: cuda
  0%|          | 0/43 [00:00<?, ?it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

  2%|▏         | 1/43 [00:00<00:26,  1.57it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

  5%|▍         | 2/43 [00:01<00:23,  1.74it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

  7%|▋         | 3/43 [00:01<00:22,  1.80it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

  9%|▉         | 4/43 [00:02<00:20,  1.91it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 12%|█▏        | 5/43 [00:02<00:19,  1.92it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 14%|█▍        | 6/43 [00:03<00:18,  1.98it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 16%|█▋        | 7/43 [00:03<00:18,  1.92it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 19%|█▊        | 8/43 [00:04<00:17,  1.96it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 21%|██        | 9/43 [00:04<00:17,  1.95it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 23%|██▎       | 10/43 [00:05<00:16,  1.99it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 26%|██▌       | 11/43 [00:05<00:17,  1.81it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 28%|██▊       | 12/43 [00:06<00:19,  1.60it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 30%|███       | 13/43 [00:07<00:20,  1.48it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 33%|███▎      | 14/43 [00:08<00:20,  1.40it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 35%|███▍      | 15/43 [00:08<00:19,  1.41it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 37%|███▋      | 16/43 [00:09<00:17,  1.54it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 40%|███▉      | 17/43 [00:09<00:15,  1.66it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 42%|████▏     | 18/43 [00:10<00:14,  1.72it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 44%|████▍     | 19/43 [00:10<00:13,  1.82it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 47%|████▋     | 20/43 [00:11<00:12,  1.87it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 49%|████▉     | 21/43 [00:11<00:11,  1.95it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 51%|█████     | 22/43 [00:12<00:10,  1.92it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 53%|█████▎    | 23/43 [00:12<00:10,  1.94it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 56%|█████▌    | 24/43 [00:13<00:09,  1.95it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 58%|█████▊    | 25/43 [00:13<00:09,  1.98it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 60%|██████    | 26/43 [00:14<00:08,  2.00it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 63%|██████▎   | 27/43 [00:15<00:08,  1.94it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 65%|██████▌   | 28/43 [00:15<00:07,  1.94it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 67%|██████▋   | 29/43 [00:16<00:07,  1.94it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 70%|██████▉   | 30/43 [00:16<00:06,  1.95it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 72%|███████▏  | 31/43 [00:17<00:06,  1.93it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 74%|███████▍  | 32/43 [00:17<00:05,  1.95it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 79%|███████▉  | 34/43 [00:18<00:04,  2.02it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 84%|████████▎ | 36/43 [00:19<00:03,  2.23it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 86%|████████▌ | 37/43 [00:20<00:03,  1.93it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 88%|████████▊ | 38/43 [00:20<00:03,  1.64it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 91%|█████████ | 39/43 [00:21<00:02,  1.49it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 93%|█████████▎| 40/43 [00:22<00:02,  1.45it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 95%|█████████▌| 41/43 [00:23<00:01,  1.57it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 98%|█████████▊| 42/43 [00:23<00:00,  1.64it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

100%|██████████| 43/43 [00:24<00:00,  1.78it/s]


In [None]:
evaluator = pytrec_eval.RelevanceEvaluator(relevant_docs, {'ndcg_cut.10', 'recall_100', 'map_cut.1000'})
scores = evaluator.evaluate(run)

print("Queries:", len(relevant_qid))
print("NDCG@10: {:.2f}".format(np.mean([ele["ndcg_cut_10"] for ele in scores.values()])*100))
print("Recall@100: {:.2f}".format(np.mean([ele["recall_100"] for ele in scores.values()])*100))
print("MAP@1000: {:.2f}".format(np.mean([ele["map_cut_1000"] for ele in scores.values()])*100))

Queries: 43
NDCG@10: 69.90
Recall@100: 50.47
MAP@1000: 45.55


Evaluation with prompted queries

In [None]:
prompted_queries_result_list = []
prompted_run = {}
model = CrossEncoder(model_save_path, max_length=512)

for qid in tqdm.tqdm(relevant_qid):
    query = reformulated_queries[qid]

    cand = passage_cand[qid]
    pids = [c[0] for c in cand]
    corpus_sentences = [c[1] for c in cand]

    cross_inp = [[query, sent] for sent in corpus_sentences]

    if model.config.num_labels > 1: #Cross-Encoder that predict more than 1 score, we use the last and apply softmax
        cross_scores = model.predict(cross_inp, apply_softmax=True)[:, 1].tolist()
    else:
        cross_scores = model.predict(cross_inp).tolist()

    cross_scores_sparse = {}
    for idx, pid in enumerate(pids):
        cross_scores_sparse[pid] = cross_scores[idx]

    sparse_scores = cross_scores_sparse
    prompted_run[qid] = {}
    for pid in sparse_scores:
        prompted_run[qid][pid] = float(sparse_scores[pid])

INFO:sentence_transformers.cross_encoder.CrossEncoder:Use pytorch device: cuda
  0%|          | 0/43 [00:00<?, ?it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

  2%|▏         | 1/43 [00:01<00:58,  1.38s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

  5%|▍         | 2/43 [00:02<00:57,  1.39s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

  7%|▋         | 3/43 [00:05<01:12,  1.82s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

  9%|▉         | 4/43 [00:07<01:22,  2.11s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 12%|█▏        | 5/43 [00:10<01:36,  2.53s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 14%|█▍        | 6/43 [00:11<01:12,  1.96s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 16%|█▋        | 7/43 [00:14<01:21,  2.26s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 19%|█▊        | 8/43 [00:15<01:04,  1.84s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 21%|██        | 9/43 [00:17<01:00,  1.78s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 23%|██▎       | 10/43 [00:18<00:56,  1.72s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 26%|██▌       | 11/43 [00:20<00:58,  1.83s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 28%|██▊       | 12/43 [00:23<01:05,  2.11s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 30%|███       | 13/43 [00:24<00:52,  1.74s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 33%|███▎      | 14/43 [00:25<00:42,  1.45s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 35%|███▍      | 15/43 [00:25<00:33,  1.21s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 37%|███▋      | 16/43 [00:27<00:38,  1.42s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 40%|███▉      | 17/43 [00:28<00:31,  1.20s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 42%|████▏     | 18/43 [00:29<00:29,  1.19s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 44%|████▍     | 19/43 [00:31<00:31,  1.31s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 47%|████▋     | 20/43 [00:34<00:43,  1.88s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 49%|████▉     | 21/43 [00:35<00:35,  1.62s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 51%|█████     | 22/43 [00:37<00:36,  1.75s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 53%|█████▎    | 23/43 [00:38<00:32,  1.62s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 56%|█████▌    | 24/43 [00:40<00:29,  1.54s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 58%|█████▊    | 25/43 [00:41<00:24,  1.35s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 60%|██████    | 26/43 [00:42<00:21,  1.27s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 63%|██████▎   | 27/43 [00:43<00:22,  1.39s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 65%|██████▌   | 28/43 [00:45<00:22,  1.47s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 67%|██████▋   | 29/43 [00:48<00:25,  1.84s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 70%|██████▉   | 30/43 [00:50<00:24,  1.88s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 72%|███████▏  | 31/43 [00:51<00:20,  1.67s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 74%|███████▍  | 32/43 [00:52<00:17,  1.61s/it]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 79%|███████▉  | 34/43 [00:54<00:10,  1.15s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 84%|████████▎ | 36/43 [00:55<00:06,  1.11it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 86%|████████▌ | 37/43 [00:57<00:06,  1.12s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 88%|████████▊ | 38/43 [00:59<00:07,  1.52s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 91%|█████████ | 39/43 [01:01<00:06,  1.57s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 93%|█████████▎| 40/43 [01:03<00:04,  1.59s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 95%|█████████▌| 41/43 [01:05<00:03,  1.68s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 98%|█████████▊| 42/43 [01:06<00:01,  1.71s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

100%|██████████| 43/43 [01:08<00:00,  1.59s/it]


In [None]:
evaluator_prompted = pytrec_eval.RelevanceEvaluator(relevant_docs, {'ndcg_cut.10', 'recall_100', 'map_cut.1000'})
scores_prompted = evaluator_prompted.evaluate(prompted_run)

print("TINYBER WITH PROMPTING")
print("Queries:", len(reformulated_queries))
print("NDCG@10: {:.2f}".format(np.mean([ele["ndcg_cut_10"] for ele in scores_prompted.values()])*100))
print("Recall@100: {:.2f}".format(np.mean([ele["recall_100"] for ele in scores_prompted.values()])*100))
print("MAP@1000: {:.2f}".format(np.mean([ele["map_cut_1000"] for ele in scores_prompted.values()])*100))

TINYBER WITH PROMPTING
Queries: 43
NDCG@10: 46.89
Recall@100: 40.73
MAP@1000: 32.94
