In [1]:
!pip install sentence-transformers



In [2]:
!pip install pytrec_eval



In [3]:
"""
This examples show how to train a Cross-Encoder for the MS Marco dataset (https://github.com/microsoft/MSMARCO-Passage-Ranking).

The query and the passage are passed simoultanously to a Transformer network. The network then returns
a score between 0 and 1 how relevant the passage is for a given query.

The resulting Cross-Encoder can then be used for passage re-ranking: You retrieve for example 100 passages
for a given query, for example with ElasticSearch, and pass the query+retrieved_passage to the CrossEncoder
for scoring. You sort the results then according to the output of the CrossEncoder.

This gives a significant boost compared to out-of-the-box ElasticSearch / BM25 ranking.
"""
from torch.utils.data import DataLoader
from sentence_transformers import LoggingHandler, util
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
from sentence_transformers import InputExample
from datetime import datetime
import gzip
import os
import tarfile
import tqdm
import logging
from collections import defaultdict
import numpy as np
import sys
import pytrec_eval
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logging.basicConfig(format='%(asctime)s - %(message)s',datefmt='%Y-%m-%d %H:%M:%S')

In [4]:
!wget https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz
!tar -xvzf  queries.tar.gz

--2024-05-26 19:57:39--  https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz
Resolving msmarco.z22.web.core.windows.net (msmarco.z22.web.core.windows.net)... 20.150.34.1
Connecting to msmarco.z22.web.core.windows.net (msmarco.z22.web.core.windows.net)|20.150.34.1|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 18882551 (18M) [application/gzip]
Saving to: ‘queries.tar.gz.2’


2024-05-26 19:57:39 (34.7 MB/s) - ‘queries.tar.gz.2’ saved [18882551/18882551]

queries.dev.tsv
queries.eval.tsv
queries.train.tsv


读取三个文件

In [5]:
"""
This file evaluates CrossEncoder on the TREC 2019 Deep Learning (DL) Track: https://arxiv.org/abs/2003.07820

TREC 2019 DL is based on the corpus of MS Marco. MS Marco provides a sparse annotation, i.e., usually only a single
passage is marked as relevant for a given query. Many other highly relevant passages are not annotated and hence are treated
as an error if a model ranks those high.

TREC DL instead annotated up to 200 passages per query for their relevance to a given query. It is better suited to estimate
the model performance for the task of reranking in Information Retrieval.

Run:
python eval_cross-encoder-trec-dl.py cross-encoder-model-name

"""


data_folder = 'trec2019-data'
os.makedirs(data_folder, exist_ok=True)

#Read test queries
queries = {}
queries_filepath = os.path.join(data_folder, 'msmarco-test2019-queries.tsv.gz')
if not os.path.exists(queries_filepath):
    logging.info("Download "+os.path.basename(queries_filepath))
    util.http_get('https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz', queries_filepath)

with gzip.open(queries_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        queries[qid] = query

#Read which passages are relevant
relevant_docs = defaultdict(lambda: defaultdict(int))
qrels_filepath = os.path.join(data_folder, '2019qrels-pass.txt')

if not os.path.exists(qrels_filepath):
    logging.info("Download "+os.path.basename(qrels_filepath))
    util.http_get('https://trec.nist.gov/data/deep/2019qrels-pass.txt', qrels_filepath)


with open(qrels_filepath) as fIn:
    for line in fIn:
        qid, _, pid, score = line.strip().split()
        score = int(score)
        if score > 0:
            relevant_docs[qid][pid] = score

# Only use queries that have at least one relevant passage
relevant_qid = []
for qid in queries:
    if len(relevant_docs[qid]) > 0:
        relevant_qid.append(qid)


# Read the top 1000 passages that are supposed to be re-ranked
passage_filepath = os.path.join(data_folder, 'msmarco-passagetest2019-top1000.tsv.gz')

if not os.path.exists(passage_filepath):
    logging.info("Download "+os.path.basename(passage_filepath))
    util.http_get('https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-passagetest2019-top1000.tsv.gz', passage_filepath)



passage_cand = {}
with gzip.open(passage_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        qid, pid, query, passage = line.strip().split("\t")
        if qid not in passage_cand:
            passage_cand[qid] = []

        passage_cand[qid].append([pid, passage])

logging.info("Queries: {}".format(len(queries)))

print("queries\n", queries)
print("relevant_docs\n", relevant_docs)
print("relevant_qid\n", relevant_qid)
print("passage_cand\n", passage_cand)

INFO:root:Queries: 200


queries
 {'1108939': 'what slows down the flow of blood', '1112389': 'what is the county for grand rapids, mn', '792752': 'what is ruclip', '1119729': 'what do you do when you have a nosebleed from having your nose', '1105095': 'where is sugar lake lodge located', '1105103': 'where is steph currys home in nc', '1128373': 'iur definition', '1127622': 'meaning of heat capacity', '1124979': 'synonym for treatment', '885490': 'what party is paul ryan in', '1119827': 'cast of sky captain and the world of tomorrow', '190044': 'foods to detox liver naturally', '500575': "sop's policy", '883785': 'what origin is the last name goins', '264403': 'how long is recovery from a face lift and neck lift', '1108100': 'what type of movement do bacteria exhibit?', '421756': 'is prorate the same as daily rate', '1108307': 'what trail did thousands use to get to the gold rush', '966413': 'where are the benefits of cinnamon as a supplement?', '1111546': 'what is the medium for an artisan', '156493': 'do gol

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [6]:
for qid in relevant_docs:
    if qid in passage_cand:
        top_docs = sorted(relevant_docs[qid].items(), key=lambda x: x[1], reverse=True)[:3]
        print(f"Top 3 relevant passages for query ID {qid}:")
        for pid, score in top_docs:
            for p in passage_cand[qid]:
                if p[0] == pid:
                    print(f"PID: {pid}, Score: {score}, Passage: {p[1]}")
                    break

Top 3 relevant passages for query ID 19335:
Top 3 relevant passages for query ID 47923:
PID: 473807, Score: 3, Passage: Quick Answer. The function of a synaptic knob is to change the action potential that is carried by axons into a chemical message. The chemical message then interacts with the recipient neuron or effector. This process is called synaptic transmission.
PID: 5417582, Score: 3, Passage: noun, singular: axon terminal. Button-like endings of axons through which axons make synaptic contacts with other nerve cells or with effector cells. Supplement. Axon terminals refer to the axon endings that are somewhat enlarged and often club-or button-shaped.Axon terminals are that part of a nerve cell that make synaptic connections with another nerve cell or with an effector cell (e.g. muscle cell or gland cell).Axon terminals contain various neurotransmitters and release them at the synapse.oun, singular: axon terminal. Button-like endings of axons through which axons make synaptic co

In [7]:
for qid in relevant_docs:
    if qid in passage_cand:
        top_docs = sorted(relevant_docs[qid].items(), key=lambda x: x[1], reverse=True)[:3]
        print(f"Top 3 relevant passages for query ID {qid}:")
        found_passage = False
        for pid, score in top_docs:
            passage_found = False
            for p in passage_cand[qid]:
                if p[0] == pid:
                    print(f"PID: {pid}, Score: {score}, Passage: {p[1]}")
                    found_passage = True
                    passage_found = True
                    break
            if not passage_found:
                print(f"Passage ID {pid} not found in the top 1000 passages list.")
        if not found_passage:
            print(f"No relevant passages found or printed for query ID {qid}.")

Top 3 relevant passages for query ID 19335:
Passage ID 3175481 not found in the top 1000 passages list.
Passage ID 3175484 not found in the top 1000 passages list.
Passage ID 8412682 not found in the top 1000 passages list.
No relevant passages found or printed for query ID 19335.
Top 3 relevant passages for query ID 47923:
PID: 473807, Score: 3, Passage: Quick Answer. The function of a synaptic knob is to change the action potential that is carried by axons into a chemical message. The chemical message then interacts with the recipient neuron or effector. This process is called synaptic transmission.
PID: 5417582, Score: 3, Passage: noun, singular: axon terminal. Button-like endings of axons through which axons make synaptic contacts with other nerve cells or with effector cells. Supplement. Axon terminals refer to the axon endings that are somewhat enlarged and often club-or button-shaped.Axon terminals are that part of a nerve cell that make synaptic connections with another nerve c

In [8]:
# Initialize a dictionary
top_docs = {}

for qid in relevant_docs:
    if qid in passage_cand:
        top_three_docs = sorted(relevant_docs[qid].items(), key=lambda x: x[1], reverse=True)[:3]
        top_docs_text = []

        for pid, score in top_three_docs:
            passage_found = False
            for p in passage_cand[qid]:
                if p[0] == pid:
                    top_docs_text.append(p[1])
                    passage_found = True
                    break
            if not passage_found:
                print(f"Passage ID {pid} not found in the top 1000 passages list.")

        if len(top_docs_text) == 3:
            top_docs[qid] = top_docs_text
        else:
            print(f"Less than 3 relevant passages found for query ID {qid}, only found {len(top_docs_text)}.")

Passage ID 3175481 not found in the top 1000 passages list.
Passage ID 3175484 not found in the top 1000 passages list.
Passage ID 8412682 not found in the top 1000 passages list.
Less than 3 relevant passages found for query ID 19335, only found 0.
Passage ID 14055 not found in the top 1000 passages list.
Passage ID 3841396 not found in the top 1000 passages list.
Passage ID 4658868 not found in the top 1000 passages list.
Less than 3 relevant passages found for query ID 131843, only found 0.
Passage ID 3883081 not found in the top 1000 passages list.
Less than 3 relevant passages found for query ID 182539, only found 2.
Passage ID 1247358 not found in the top 1000 passages list.
Passage ID 1289489 not found in the top 1000 passages list.
Less than 3 relevant passages found for query ID 183378, only found 1.
Passage ID 1029962 not found in the top 1000 passages list.
Less than 3 relevant passages found for query ID 207786, only found 2.
Passage ID 3440848 not found in the top 1000 pas

通过relevant_docs中的每个query的文档相关性评分，得到每个query相关性最高的前三个文档，然后passage_cand中找到文档的相关内容

In [9]:
!pip install bitsandbytes==0.43.1
!pip install transformers==4.40.2
!pip install peft==0.11.1
!pip install accelerate==0.30.1



In [10]:
import os
import json
import tqdm
import sys

使用生成模型进行查询扩展：

In [11]:
from transformers import AutoTokenizer, LlamaForCausalLM, AutoModelForCausalLM

model_name = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(model_name, truncation=True, padding=True, padding_side="left", maximum_length = 2048, model_max_length = 2048)
model = AutoModelForCausalLM.from_pretrained(model_name, load_in_4bit = True, device_map = 'auto')
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = model.generation_config.eos_token_id

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /HuggingFaceH4/zephyr-7b-beta/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 136992128840800 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/c57f1bf6ea28d2e3ca4540709beb3a80815c2aab.lock
DEBUG:filelock:Lock 136992128840800 acquired on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/c57f1bf6ea28d2e3ca4540709beb3a80815c2aab.lock
DEBUG:urllib3

tokenizer_config.json:   0%|          | 0.00/1.43k [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992128840800 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/c57f1bf6ea28d2e3ca4540709beb3a80815c2aab.lock
DEBUG:filelock:Lock 136992128840800 released on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/c57f1bf6ea28d2e3ca4540709beb3a80815c2aab.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /HuggingFaceH4/zephyr-7b-beta/resolve/main/tokenizer.model HTTP/1.1" 302 0
DEBUG:filelock:Attempting to acquire lock 136992128445424 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055.lock
DEBUG:filelock:Lock 136992128445424 acquired on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): cdn-lfs-us-1.huggingface.co:443
DEBUG:urllib3.connect

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992128445424 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055.lock
DEBUG:filelock:Lock 136992128445424 released on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /HuggingFaceH4/zephyr-7b-beta/resolve/main/tokenizer.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 136992128445184 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/43e6daf936dc0f953cb867ec864adab78f92d9ce.lock
DEBUG:filelock:Lock 136992128445184 acquired on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/43e6daf936dc0f953cb867ec864adab78f92d9ce.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /HuggingFaceH4/zephyr-7b-beta/resolve/main/tokenizer.json

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992128445184 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/43e6daf936dc0f953cb867ec864adab78f92d9ce.lock
DEBUG:filelock:Lock 136992128445184 released on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/43e6daf936dc0f953cb867ec864adab78f92d9ce.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /HuggingFaceH4/zephyr-7b-beta/resolve/main/added_tokens.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 136992128445424 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/cbce74e5c64b97114098962fa58454a57d7fb532.lock
DEBUG:filelock:Lock 136992128445424 acquired on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/cbce74e5c64b97114098962fa58454a57d7fb532.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /HuggingFaceH4/zephyr-7b-beta/resolve/main/added_tokens.json HTTP/1.1" 200 42


added_tokens.json:   0%|          | 0.00/42.0 [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992128445424 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/cbce74e5c64b97114098962fa58454a57d7fb532.lock
DEBUG:filelock:Lock 136992128445424 released on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/cbce74e5c64b97114098962fa58454a57d7fb532.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /HuggingFaceH4/zephyr-7b-beta/resolve/main/special_tokens_map.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 136992128445424 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/8cd5f1eb30d4e97d74cbf915c36db116aea5eca7.lock
DEBUG:filelock:Lock 136992128445424 acquired on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/8cd5f1eb30d4e97d74cbf915c36db116aea5eca7.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /HuggingFaceH4/zephyr-7b-beta/resolve/main/special_tokens_map.json HTTP/1.1" 200 168


special_tokens_map.json:   0%|          | 0.00/168 [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992128445424 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/8cd5f1eb30d4e97d74cbf915c36db116aea5eca7.lock
DEBUG:filelock:Lock 136992128445424 released on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/8cd5f1eb30d4e97d74cbf915c36db116aea5eca7.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /HuggingFaceH4/zephyr-7b-beta/resolve/main/config.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 136992126997680 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/c3c550ad931ec57781ffa6f5bf4682e30bb6fbef.lock
DEBUG:filelock:Lock 136992126997680 acquired on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/c3c550ad931ec57781ffa6f5bf4682e30bb6fbef.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /HuggingFaceH4/zephyr-7b-beta/resolve/main/config.json HTTP/1.1" 200 638


config.json:   0%|          | 0.00/638 [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992126997680 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/c3c550ad931ec57781ffa6f5bf4682e30bb6fbef.lock
DEBUG:filelock:Lock 136992126997680 released on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/c3c550ad931ec57781ffa6f5bf4682e30bb6fbef.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /HuggingFaceH4/zephyr-7b-beta/resolve/main/adapter_config.json HTTP/1.1" 404 0
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /HuggingFaceH4/zephyr-7b-beta/resolve/main/model.safetensors HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /HuggingFaceH4/zephyr-7b-beta/resolve/main/model.safetensors.index.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992126997296 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/fbc869b880f0c7287847c72de41d71522f62b685.lock
DEBUG:filelock:Lock 136992126997296 released on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/fbc869b880f0c7287847c72de41d71522f62b685.lock


Downloading shards:   0%|          | 0/8 [00:00<?, ?it/s]

DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /HuggingFaceH4/zephyr-7b-beta/resolve/main/model-00001-of-00008.safetensors HTTP/1.1" 302 0
DEBUG:filelock:Attempting to acquire lock 136992118109712 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/a6ec4a0398df9d56786afa9cce026423ec48a07ddc5ed5eba087614cae2dd746.lock
DEBUG:filelock:Lock 136992118109712 acquired on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/a6ec4a0398df9d56786afa9cce026423ec48a07ddc5ed5eba087614cae2dd746.lock
DEBUG:urllib3.connectionpool:https://cdn-lfs-us-1.huggingface.co:443 "GET /repos/d9/3d/d93d0ae44e3930a5eb272129c6a12ccec827e219c3d5ba5474ae9ddf3b4b7647/a6ec4a0398df9d56786afa9cce026423ec48a07ddc5ed5eba087614cae2dd746?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27model-00001-of-00008.safetensors%3B+filename%3D%22model-00001-of-00008.safetensors%22%3B&Expires=1717012708&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1R

model-00001-of-00008.safetensors:   0%|          | 0.00/1.89G [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992118109712 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/a6ec4a0398df9d56786afa9cce026423ec48a07ddc5ed5eba087614cae2dd746.lock
DEBUG:filelock:Lock 136992118109712 released on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/a6ec4a0398df9d56786afa9cce026423ec48a07ddc5ed5eba087614cae2dd746.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /HuggingFaceH4/zephyr-7b-beta/resolve/main/model-00002-of-00008.safetensors HTTP/1.1" 302 0
DEBUG:filelock:Attempting to acquire lock 136992118110288 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/1289796f16f33ef4c6b8a76b3d9e5169198a69daa0b5b660b5a3d5cae0dc1cf7.lock
DEBUG:filelock:Lock 136992118110288 acquired on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/1289796f16f33ef4c6b8a76b3d9e5169198a69daa0b5b660b5a3d5cae0dc1cf7.lock
DEBUG:urllib3.connectionpool:https://cdn-lfs-us-1.hu

model-00002-of-00008.safetensors:   0%|          | 0.00/1.95G [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992118110288 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/1289796f16f33ef4c6b8a76b3d9e5169198a69daa0b5b660b5a3d5cae0dc1cf7.lock
DEBUG:filelock:Lock 136992118110288 released on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/1289796f16f33ef4c6b8a76b3d9e5169198a69daa0b5b660b5a3d5cae0dc1cf7.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /HuggingFaceH4/zephyr-7b-beta/resolve/main/model-00003-of-00008.safetensors HTTP/1.1" 302 0
DEBUG:filelock:Attempting to acquire lock 136992118105248 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/6fc338de96c672840222ca19a12b44f09110048974bde8ab0f4c1da055d99c3f.lock
DEBUG:filelock:Lock 136992118105248 acquired on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/6fc338de96c672840222ca19a12b44f09110048974bde8ab0f4c1da055d99c3f.lock
DEBUG:urllib3.connectionpool:https://cdn-lfs-us-1.hu

model-00003-of-00008.safetensors:   0%|          | 0.00/1.98G [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992118105248 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/6fc338de96c672840222ca19a12b44f09110048974bde8ab0f4c1da055d99c3f.lock
DEBUG:filelock:Lock 136992118105248 released on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/6fc338de96c672840222ca19a12b44f09110048974bde8ab0f4c1da055d99c3f.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /HuggingFaceH4/zephyr-7b-beta/resolve/main/model-00004-of-00008.safetensors HTTP/1.1" 302 0
DEBUG:filelock:Attempting to acquire lock 136992118109904 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/92d948413fd7ac6c4fd5f6c36902bb7a72d2cddca19628c88e6f3b3e5482ab37.lock
DEBUG:filelock:Lock 136992118109904 acquired on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/92d948413fd7ac6c4fd5f6c36902bb7a72d2cddca19628c88e6f3b3e5482ab37.lock
DEBUG:urllib3.connectionpool:https://cdn-lfs-us-1.hu

model-00004-of-00008.safetensors:   0%|          | 0.00/1.95G [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992118109904 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/92d948413fd7ac6c4fd5f6c36902bb7a72d2cddca19628c88e6f3b3e5482ab37.lock
DEBUG:filelock:Lock 136992118109904 released on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/92d948413fd7ac6c4fd5f6c36902bb7a72d2cddca19628c88e6f3b3e5482ab37.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /HuggingFaceH4/zephyr-7b-beta/resolve/main/model-00005-of-00008.safetensors HTTP/1.1" 302 0
DEBUG:filelock:Attempting to acquire lock 136992118105200 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/d0ce3cc66d224a0e6014a1df6b7e56da29f5db52da1866ad0cc07d7583fb7c31.lock
DEBUG:filelock:Lock 136992118105200 acquired on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/d0ce3cc66d224a0e6014a1df6b7e56da29f5db52da1866ad0cc07d7583fb7c31.lock
DEBUG:urllib3.connectionpool:https://cdn-lfs-us-1.hu

model-00005-of-00008.safetensors:   0%|          | 0.00/1.98G [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992118105200 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/d0ce3cc66d224a0e6014a1df6b7e56da29f5db52da1866ad0cc07d7583fb7c31.lock
DEBUG:filelock:Lock 136992118105200 released on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/d0ce3cc66d224a0e6014a1df6b7e56da29f5db52da1866ad0cc07d7583fb7c31.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /HuggingFaceH4/zephyr-7b-beta/resolve/main/model-00006-of-00008.safetensors HTTP/1.1" 302 0
DEBUG:filelock:Attempting to acquire lock 136992118109664 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/0b63ef55219d125eaeefabd2a33c358dced22b9c75d1d816bb7a871bd3773951.lock
DEBUG:filelock:Lock 136992118109664 acquired on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/0b63ef55219d125eaeefabd2a33c358dced22b9c75d1d816bb7a871bd3773951.lock
DEBUG:urllib3.connectionpool:https://cdn-lfs-us-1.hu

model-00006-of-00008.safetensors:   0%|          | 0.00/1.95G [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992118109664 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/0b63ef55219d125eaeefabd2a33c358dced22b9c75d1d816bb7a871bd3773951.lock
DEBUG:filelock:Lock 136992118109664 released on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/0b63ef55219d125eaeefabd2a33c358dced22b9c75d1d816bb7a871bd3773951.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /HuggingFaceH4/zephyr-7b-beta/resolve/main/model-00007-of-00008.safetensors HTTP/1.1" 302 0
DEBUG:filelock:Attempting to acquire lock 136992118109664 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/b0e0a3c0b992925ff9a60f1234950af0bcc7ce3015c8a386a342489e76f5d09c.lock
DEBUG:filelock:Lock 136992118109664 acquired on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/b0e0a3c0b992925ff9a60f1234950af0bcc7ce3015c8a386a342489e76f5d09c.lock
DEBUG:urllib3.connectionpool:https://cdn-lfs-us-1.hu

model-00007-of-00008.safetensors:   0%|          | 0.00/1.98G [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992118109664 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/b0e0a3c0b992925ff9a60f1234950af0bcc7ce3015c8a386a342489e76f5d09c.lock
DEBUG:filelock:Lock 136992118109664 released on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/b0e0a3c0b992925ff9a60f1234950af0bcc7ce3015c8a386a342489e76f5d09c.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /HuggingFaceH4/zephyr-7b-beta/resolve/main/model-00008-of-00008.safetensors HTTP/1.1" 302 0
DEBUG:filelock:Attempting to acquire lock 136992128444896 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/667dd6958d5137dcba514db514c7970d18c9cd5d95fbb0c82490d558bd2a246c.lock
DEBUG:filelock:Lock 136992128444896 acquired on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/667dd6958d5137dcba514db514c7970d18c9cd5d95fbb0c82490d558bd2a246c.lock
DEBUG:urllib3.connectionpool:https://cdn-lfs-us-1.hu

model-00008-of-00008.safetensors:   0%|          | 0.00/816M [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992128444896 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/667dd6958d5137dcba514db514c7970d18c9cd5d95fbb0c82490d558bd2a246c.lock
DEBUG:filelock:Lock 136992128444896 released on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/667dd6958d5137dcba514db514c7970d18c9cd5d95fbb0c82490d558bd2a246c.lock
DEBUG:bitsandbytes.cextension:Loading bitsandbytes native library from: /usr/local/lib/python3.10/dist-packages/bitsandbytes/libbitsandbytes_cuda121.so
INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /HuggingFaceH4/zephyr-7b-beta/resolve/main/generation_config.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 136992012670992 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/38dd6f7cf521e4797c68803f67cfb1331d606353.lock
DEBUG:filelock:Lock 136992012670992 acquired on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/38dd6f7cf521e4797c68803f67cfb1331d606353.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /HuggingFaceH4/zephyr-7b-beta/resolve/main/generation_config.json HTTP/1.1" 200 111


generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992012670992 on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/38dd6f7cf521e4797c68803f67cfb1331d606353.lock
DEBUG:filelock:Lock 136992012670992 released on /root/.cache/huggingface/hub/.locks/models--HuggingFaceH4--zephyr-7b-beta/38dd6f7cf521e4797c68803f67cfb1331d606353.lock


In [12]:
import gzip
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, BertForSequenceClassification
from collections import defaultdict
import pytrec_eval
import numpy as np

generator = pipeline('text-generation', model=model, tokenizer=tokenizer)

def generate_expanded_queries(queries, top_docs):
    expanded_queries = {}
    for query_id in queries:
        if query_id in top_docs and len(top_docs[query_id]) == 3:
            PRF_doc_1, PRF_doc_2, PRF_doc_3 = top_docs[query_id]
            query_text = queries[query_id]
            input_text = (f"Write a list of keywords for the given query based on the context:\n"
                          f"Context: {PRF_doc_1}\n{PRF_doc_2}\n{PRF_doc_3}\nQuery: {query_text}\nKeywords:")

            expanded_query = generator(input_text, max_length=512, num_return_sequences=1)
            generated_text = expanded_query[0]['generated_text'].strip()

            concatenated_query = f"{query_text} {generated_text}"
            expanded_queries[query_id] = concatenated_query
            print("Expanded query:", concatenated_query)
        else:
            print(f"No enough docs for query {query_id}. Needed 3, got {len(top_docs.get(query_id, []))}")

    return expanded_queries

expanded_queries = generate_expanded_queries(queries, top_docs)

INFO:numexpr.utils:NumExpr defaulting to 2 threads.
DEBUG:tensorflow:Falling back to TensorFlow client; we recommended you install the Cloud TPU client directly with pip install cloud-tpu-client.
DEBUG:h5py._conv:Creating converter from 7 to 5
DEBUG:h5py._conv:Creating converter from 5 to 7
DEBUG:h5py._conv:Creating converter from 7 to 5
DEBUG:h5py._conv:Creating converter from 5 to 7
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


No enough docs for query 1108939. Needed 3, got 0
No enough docs for query 1112389. Needed 3, got 0
No enough docs for query 792752. Needed 3, got 0
No enough docs for query 1119729. Needed 3, got 0
No enough docs for query 1105095. Needed 3, got 0
No enough docs for query 1105103. Needed 3, got 0
No enough docs for query 1128373. Needed 3, got 0
No enough docs for query 1127622. Needed 3, got 0
No enough docs for query 1124979. Needed 3, got 0
No enough docs for query 885490. Needed 3, got 0
No enough docs for query 1119827. Needed 3, got 0
No enough docs for query 190044. Needed 3, got 0
No enough docs for query 500575. Needed 3, got 0
No enough docs for query 883785. Needed 3, got 0
No enough docs for query 264403. Needed 3, got 0
No enough docs for query 1108100. Needed 3, got 0
No enough docs for query 421756. Needed 3, got 0
No enough docs for query 1108307. Needed 3, got 0
No enough docs for query 966413. Needed 3, got 0
No enough docs for query 1111546. Needed 3, got 0




Expanded query: do goldfish grow Write a list of keywords for the given query based on the context:
Context: A: The conditions goldfish are kept in plus their diet determine how large they will grow. I have seen goldfish grow ridiculously large in very small containers when their water was changed frequently. Goldfish will not grow if water conditions are poor. Fancy goldfish don’t grow as large as Common goldfish. A good size would be around 5 inches body length for most fancy varieties, 8 inches for Comets and 12 inches for Common Goldfish. These sizes are usually only attained by pond grown fish.
Goldfish can live for many years if properly cared for, a bowl is not large enough. My goldfish lived 10 years in a 27 gallon tank but may have lived longer in a pond. 1 liter won't be enough - goldfish can grow to 10 inches long or even larger than that. Thanks guys I was so worried but not anymore.
Shubunkin goldfish can grow to be 12 inches, so according to this rule, you should give eac

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


Expanded query: when was the salvation army founded Write a list of keywords for the given query based on the context:
Context: William Booth. William Booth (April 10, 1829 – August 20,1912) was the founder and first General (1878-1912) of The Salvation Army. Originally a Methodist lay preacher, William Booth saw tremendous need not being fulfilled by mainstream churches in Victorian England.n his honor, Vachel Lindsay wrote the poem General William Booth Enters Heaven, and Charles Ives, who had been Booth's neighbor, set it to music. William Booth was succeeded by his son, Bramwell Booth, who became the second General of the Salvation Army, serving from 1912 to 1929.
William Booth (April 10, 1829 – August 20,1912) was the founder and first General (1878-1912) of The Salvation Army.Originally a Methodist lay preacher, William Booth saw tremendous need not being fulfilled by mainstream churches in Victorian England.n his honor, Vachel Lindsay wrote the poem General William Booth Enters 

In [13]:
expanded_queries

{'156493': "do goldfish grow Write a list of keywords for the given query based on the context:\nContext: A: The conditions goldfish are kept in plus their diet determine how large they will grow. I have seen goldfish grow ridiculously large in very small containers when their water was changed frequently. Goldfish will not grow if water conditions are poor. Fancy goldfish don’t grow as large as Common goldfish. A good size would be around 5 inches body length for most fancy varieties, 8 inches for Comets and 12 inches for Common Goldfish. These sizes are usually only attained by pond grown fish.\nGoldfish can live for many years if properly cared for, a bowl is not large enough. My goldfish lived 10 years in a 27 gallon tank but may have lived longer in a pond. 1 liter won't be enough - goldfish can grow to 10 inches long or even larger than that. Thanks guys I was so worried but not anymore.\nShubunkin goldfish can grow to be 12 inches, so according to this rule, you should give each

In [14]:
import pickle
from transformers import AutoTokenizer
def truncate_queries(queries, tokenizer, max_length=512):
    truncated_queries = {}
    for qid, query in queries.items():
        tokens = tokenizer.encode(query, truncation=True, max_length=max_length)
        truncated_query = tokenizer.decode(tokens, skip_special_tokens=True)
        truncated_queries[str(qid)] = truncated_query
    return truncated_queries

sim_model_name = 'bert-base-uncased'
sim_tokenizer = AutoTokenizer.from_pretrained(sim_model_name)

truncated_expanded_queries = truncate_queries(expanded_queries, sim_tokenizer, max_length=512)

with open('truncated_expanded_queries.pkl', 'wb') as f:
    pickle.dump(truncated_expanded_queries, f)

print("Truncated expanded queries have been saved to truncated_expanded_queries.pkl")



DEBUG:urllib3.connectionpool:Resetting dropped connection: huggingface.co
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 136991695437008 on /root/.cache/huggingface/hub/.locks/models--bert-base-uncased/e5c73d8a50df1f56fb5b0b8002d7cf4010afdccb.lock
DEBUG:filelock:Lock 136991695437008 acquired on /root/.cache/huggingface/hub/.locks/models--bert-base-uncased/e5c73d8a50df1f56fb5b0b8002d7cf4010afdccb.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /bert-base-uncased/resolve/main/tokenizer_config.json HTTP/1.1" 200 48


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136991695437008 on /root/.cache/huggingface/hub/.locks/models--bert-base-uncased/e5c73d8a50df1f56fb5b0b8002d7cf4010afdccb.lock
DEBUG:filelock:Lock 136991695437008 released on /root/.cache/huggingface/hub/.locks/models--bert-base-uncased/e5c73d8a50df1f56fb5b0b8002d7cf4010afdccb.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/config.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 136992118333280 on /root/.cache/huggingface/hub/.locks/models--bert-base-uncased/45a2321a7ecfdaaf60a6c1fd7f5463994cc8907d.lock
DEBUG:filelock:Lock 136992118333280 acquired on /root/.cache/huggingface/hub/.locks/models--bert-base-uncased/45a2321a7ecfdaaf60a6c1fd7f5463994cc8907d.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /bert-base-uncased/resolve/main/config.json HTTP/1.1" 200 570


config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992118333280 on /root/.cache/huggingface/hub/.locks/models--bert-base-uncased/45a2321a7ecfdaaf60a6c1fd7f5463994cc8907d.lock
DEBUG:filelock:Lock 136992118333280 released on /root/.cache/huggingface/hub/.locks/models--bert-base-uncased/45a2321a7ecfdaaf60a6c1fd7f5463994cc8907d.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/vocab.txt HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 136992118333472 on /root/.cache/huggingface/hub/.locks/models--bert-base-uncased/fb140275c155a9c7c5a3b3e0e77a9e839594a938.lock
DEBUG:filelock:Lock 136992118333472 acquired on /root/.cache/huggingface/hub/.locks/models--bert-base-uncased/fb140275c155a9c7c5a3b3e0e77a9e839594a938.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /bert-base-uncased/resolve/main/vocab.txt HTTP/1.1" 200 231508


vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992118333472 on /root/.cache/huggingface/hub/.locks/models--bert-base-uncased/fb140275c155a9c7c5a3b3e0e77a9e839594a938.lock
DEBUG:filelock:Lock 136992118333472 released on /root/.cache/huggingface/hub/.locks/models--bert-base-uncased/fb140275c155a9c7c5a3b3e0e77a9e839594a938.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/tokenizer.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 136992118330496 on /root/.cache/huggingface/hub/.locks/models--bert-base-uncased/949a6f013d67eb8a5b4b5b46026217b888021b88.lock
DEBUG:filelock:Lock 136992118330496 acquired on /root/.cache/huggingface/hub/.locks/models--bert-base-uncased/949a6f013d67eb8a5b4b5b46026217b888021b88.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /bert-base-uncased/resolve/main/tokenizer.json HTTP/1.1" 200 466062


tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 136992118330496 on /root/.cache/huggingface/hub/.locks/models--bert-base-uncased/949a6f013d67eb8a5b4b5b46026217b888021b88.lock
DEBUG:filelock:Lock 136992118330496 released on /root/.cache/huggingface/hub/.locks/models--bert-base-uncased/949a6f013d67eb8a5b4b5b46026217b888021b88.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/added_tokens.json HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/special_tokens_map.json HTTP/1.1" 404 0


Truncated expanded queries have been saved to truncated_expanded_queries.pkl


## Evaluating preparation

### Initialize hyperparameters (e.g., batch size, etc)

In [25]:
from google.colab import drive
drive.mount('/content/gdrive')
base_path = "./gdrive/MyDrive/cross-encoder-reranker-ir-course-2023/"

Mounted at /content/gdrive


In [26]:
!mkdir -p $base_path

## Evaluate the model


### Load the fine-tuned model that you trained using the previous notebook. You need to set the path of your own fine-tuned model here.

In [54]:
model_save_path = "/content/gdrive/MyDrive/cross-encoder-reranker-ir-course-2023/finetuned_models/cross-encoder-distilbert-distilroberta-base-2024-05-12_07-39-41" #@param {type:"string"}

### Load data (For evaluation on TREC DL'19)

In [55]:
!wget https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz
!tar -xvzf  queries.tar.gz

--2024-05-26 20:54:48--  https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz
Resolving msmarco.z22.web.core.windows.net (msmarco.z22.web.core.windows.net)... 20.150.34.1
Connecting to msmarco.z22.web.core.windows.net (msmarco.z22.web.core.windows.net)|20.150.34.1|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 18882551 (18M) [application/gzip]
Saving to: ‘queries.tar.gz.7’


2024-05-26 20:54:49 (41.5 MB/s) - ‘queries.tar.gz.7’ saved [18882551/18882551]

queries.dev.tsv
queries.eval.tsv
queries.train.tsv


In [56]:
import pickle
"""
This file evaluates CrossEncoder on the TREC 2019 Deep Learning (DL) Track: https://arxiv.org/abs/2003.07820

TREC 2019 DL is based on the corpus of MS Marco. MS Marco provides a sparse annotation, i.e., usually only a single
passage is marked as relevant for a given query. Many other highly relevant passages are not annotated and hence are treated
as an error if a model ranks those high.

TREC DL instead annotated up to 200 passages per query for their relevance to a given query. It is better suited to estimate
the model performance for the task of reranking in Information Retrieval.

Run:
python eval_cross-encoder-trec-dl.py cross-encoder-model-name

"""


data_folder = 'trec2019-data'
os.makedirs(data_folder, exist_ok=True)

#Read test queries
# queries = {}
# queries_filepath = os.path.join(data_folder, 'msmarco-test2019-queries.tsv.gz')
# if not os.path.exists(queries_filepath):
#     logging.info("Download "+os.path.basename(queries_filepath))
#     util.http_get('https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz', queries_filepath)

# with gzip.open(queries_filepath, 'rt', encoding='utf8') as fIn:
#     for line in fIn:
#         qid, query = line.strip().split("\t")
#         queries[qid] = query
with open('truncated_expanded_queries.pkl', 'rb') as f:
    queries = pickle.load(f)


#Read which passages are relevant
relevant_docs = defaultdict(lambda: defaultdict(int))
qrels_filepath = os.path.join(data_folder, '2019qrels-pass.txt')

if not os.path.exists(qrels_filepath):
    logging.info("Download "+os.path.basename(qrels_filepath))
    util.http_get('https://trec.nist.gov/data/deep/2019qrels-pass.txt', qrels_filepath)


with open(qrels_filepath) as fIn:
    for line in fIn:
        qid, _, pid, score = line.strip().split()
        score = int(score)
        if score > 0:
            relevant_docs[qid][pid] = score

# Only use queries that have at least one relevant passage
relevant_qid = []
for qid in queries:
    if len(relevant_docs[qid]) > 0:
        relevant_qid.append(qid)


# Read the top 1000 passages that are supposed to be re-ranked
passage_filepath = os.path.join(data_folder, 'msmarco-passagetest2019-top1000.tsv.gz')

if not os.path.exists(passage_filepath):
    logging.info("Download "+os.path.basename(passage_filepath))
    util.http_get('https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-passagetest2019-top1000.tsv.gz', passage_filepath)



passage_cand = {}
with gzip.open(passage_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        qid, pid, query, passage = line.strip().split("\t")
        if qid not in passage_cand:
            passage_cand[qid] = []

        passage_cand[qid].append([pid, passage])

logging.info("Queries: {}".format(len(queries)))


INFO:root:Queries: 26


## Prediction

### Mini

In [37]:
queries_result_list = []
run = {}
model = CrossEncoder(model_save_path, max_length=512)

for qid in tqdm.tqdm(relevant_qid):
    query = queries[qid]

    cand = passage_cand[qid]
    pids = [c[0] for c in cand]
    corpus_sentences = [c[1] for c in cand]

    cross_inp = [[query, sent] for sent in corpus_sentences]

    if model.config.num_labels > 1: #Cross-Encoder that predict more than 1 score, we use the last and apply softmax
        cross_scores = model.predict(cross_inp, apply_softmax=True)[:, 1].tolist()
    else:
        cross_scores = model.predict(cross_inp).tolist()

    cross_scores_sparse = {}
    for idx, pid in enumerate(pids):
        cross_scores_sparse[pid] = cross_scores[idx]

    sparse_scores = cross_scores_sparse
    run[qid] = {}
    for pid in sparse_scores:
        run[qid][pid] = float(sparse_scores[pid])

INFO:sentence_transformers.cross_encoder.CrossEncoder:Use pytorch device: cuda
  0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

  4%|▍         | 1/26 [00:02<00:54,  2.19s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

  8%|▊         | 2/26 [00:04<00:47,  1.98s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 12%|█▏        | 3/26 [00:05<00:45,  1.96s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 15%|█▌        | 4/26 [00:08<00:51,  2.34s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 19%|█▉        | 5/26 [00:10<00:45,  2.18s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 23%|██▎       | 6/26 [00:13<00:45,  2.27s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 27%|██▋       | 7/26 [00:15<00:44,  2.34s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 31%|███       | 8/26 [00:18<00:43,  2.39s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 35%|███▍      | 9/26 [00:20<00:41,  2.46s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 38%|███▊      | 10/26 [00:23<00:41,  2.58s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 42%|████▏     | 11/26 [00:26<00:38,  2.56s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 46%|████▌     | 12/26 [00:28<00:35,  2.56s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 50%|█████     | 13/26 [00:31<00:33,  2.57s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 54%|█████▍    | 14/26 [00:33<00:30,  2.51s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 58%|█████▊    | 15/26 [00:35<00:25,  2.32s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 62%|██████▏   | 16/26 [00:38<00:24,  2.48s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 65%|██████▌   | 17/26 [00:40<00:21,  2.42s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 69%|██████▉   | 18/26 [00:43<00:19,  2.42s/it]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

 73%|███████▎  | 19/26 [00:43<00:12,  1.74s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 77%|███████▋  | 20/26 [00:45<00:11,  1.88s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 85%|████████▍ | 22/26 [00:47<00:06,  1.54s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 88%|████████▊ | 23/26 [00:50<00:05,  1.88s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 92%|█████████▏| 24/26 [00:52<00:03,  1.90s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 96%|█████████▌| 25/26 [00:54<00:01,  1.95s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

100%|██████████| 26/26 [00:57<00:00,  2.20s/it]


### Distil

In [59]:
queries_result_list = []
run = {}
model = CrossEncoder(model_save_path, max_length=512)

for qid in tqdm.tqdm(relevant_qid):
    query = queries[qid]

    cand = passage_cand[qid]
    pids = [c[0] for c in cand]
    corpus_sentences = [c[1] for c in cand]

    cross_inp = [[query, sent] for sent in corpus_sentences]

    if model.config.num_labels > 1: #Cross-Encoder that predict more than 1 score, we use the last and apply softmax
        cross_scores = model.predict(cross_inp, apply_softmax=True)[:, 1].tolist()
    else:
        cross_scores = model.predict(cross_inp).tolist()

    cross_scores_sparse = {}
    for idx, pid in enumerate(pids):
        cross_scores_sparse[pid] = cross_scores[idx]

    sparse_scores = cross_scores_sparse
    run[qid] = {}
    for pid in sparse_scores:
        run[qid][pid] = float(sparse_scores[pid])

INFO:sentence_transformers.cross_encoder.CrossEncoder:Use pytorch device: cuda
  0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

  4%|▍         | 1/26 [00:15<06:25, 15.44s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

  8%|▊         | 2/26 [00:28<05:38, 14.10s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 12%|█▏        | 3/26 [00:39<04:54, 12.80s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 15%|█▌        | 4/26 [00:57<05:20, 14.59s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 19%|█▉        | 5/26 [01:11<05:04, 14.49s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 23%|██▎       | 6/26 [01:28<05:06, 15.32s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 27%|██▋       | 7/26 [01:45<05:01, 15.88s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 31%|███       | 8/26 [02:02<04:54, 16.37s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 35%|███▍      | 9/26 [02:20<04:42, 16.62s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 38%|███▊      | 10/26 [02:37<04:28, 16.79s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 42%|████▏     | 11/26 [02:54<04:13, 16.90s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 46%|████▌     | 12/26 [03:11<03:58, 17.02s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 50%|█████     | 13/26 [03:28<03:42, 17.11s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 54%|█████▍    | 14/26 [03:44<03:17, 16.50s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 58%|█████▊    | 15/26 [03:54<02:42, 14.79s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 62%|██████▏   | 16/26 [04:12<02:35, 15.53s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 65%|██████▌   | 17/26 [04:29<02:23, 15.94s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 69%|██████▉   | 18/26 [04:44<02:06, 15.87s/it]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

 73%|███████▎  | 19/26 [04:45<01:19, 11.30s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 77%|███████▋  | 20/26 [05:00<01:14, 12.39s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 85%|████████▍ | 22/26 [05:14<00:40, 10.00s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 88%|████████▊ | 23/26 [05:31<00:35, 11.77s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 92%|█████████▏| 24/26 [05:46<00:24, 12.47s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 96%|█████████▌| 25/26 [06:02<00:13, 13.35s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

100%|██████████| 26/26 [06:19<00:00, 14.58s/it]


### Tiny

In [52]:
queries_result_list = []
run = {}
model = CrossEncoder(model_save_path, max_length=512)

for qid in tqdm.tqdm(relevant_qid):
    query = queries[qid]

    cand = passage_cand[qid]
    pids = [c[0] for c in cand]
    corpus_sentences = [c[1] for c in cand]

    cross_inp = [[query, sent] for sent in corpus_sentences]

    if model.config.num_labels > 1: #Cross-Encoder that predict more than 1 score, we use the last and apply softmax
        cross_scores = model.predict(cross_inp, apply_softmax=True)[:, 1].tolist()
    else:
        cross_scores = model.predict(cross_inp).tolist()

    cross_scores_sparse = {}
    for idx, pid in enumerate(pids):
        cross_scores_sparse[pid] = cross_scores[idx]

    sparse_scores = cross_scores_sparse
    run[qid] = {}
    for pid in sparse_scores:
        run[qid][pid] = float(sparse_scores[pid])

INFO:sentence_transformers.cross_encoder.CrossEncoder:Use pytorch device: cuda
  0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

  4%|▍         | 1/26 [00:02<00:55,  2.22s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

  8%|▊         | 2/26 [00:04<00:52,  2.20s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 12%|█▏        | 3/26 [00:05<00:41,  1.80s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 15%|█▌        | 4/26 [00:07<00:39,  1.77s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 19%|█▉        | 5/26 [00:08<00:33,  1.60s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 23%|██▎       | 6/26 [00:10<00:32,  1.64s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 27%|██▋       | 7/26 [00:11<00:29,  1.57s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 31%|███       | 8/26 [00:13<00:29,  1.61s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 35%|███▍      | 9/26 [00:15<00:27,  1.61s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 38%|███▊      | 10/26 [00:18<00:31,  1.98s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 42%|████▏     | 11/26 [00:20<00:32,  2.17s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 46%|████▌     | 12/26 [00:22<00:29,  2.09s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 50%|█████     | 13/26 [00:24<00:26,  2.07s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 54%|█████▍    | 14/26 [00:25<00:22,  1.87s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 58%|█████▊    | 15/26 [00:27<00:17,  1.62s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 62%|██████▏   | 16/26 [00:28<00:16,  1.66s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 65%|██████▌   | 17/26 [00:30<00:14,  1.65s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 69%|██████▉   | 18/26 [00:32<00:15,  1.88s/it]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

 73%|███████▎  | 19/26 [00:32<00:09,  1.36s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 77%|███████▋  | 20/26 [00:35<00:09,  1.61s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 85%|████████▍ | 22/26 [00:36<00:04,  1.18s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 88%|████████▊ | 23/26 [00:38<00:03,  1.31s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 92%|█████████▏| 24/26 [00:39<00:02,  1.26s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 96%|█████████▌| 25/26 [00:40<00:01,  1.30s/it]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

100%|██████████| 26/26 [00:42<00:00,  1.62s/it]


## Evaluation

### Mini

In [38]:
evaluator = pytrec_eval.RelevanceEvaluator(relevant_docs, {'ndcg_cut.10', 'recall_100', 'map_cut.1000'})
scores = evaluator.evaluate(run)

print("Queries:", len(relevant_qid))
print("NDCG@10: {:.2f}".format(np.mean([ele["ndcg_cut_10"] for ele in scores.values()])*100))
print("Recall@100: {:.2f}".format(np.mean([ele["recall_100"] for ele in scores.values()])*100))
print("MAP@1000: {:.2f}".format(np.mean([ele["map_cut_1000"] for ele in scores.values()])*100))

Queries: 26
NDCG@10: 65.07
Recall@100: 50.03
MAP@1000: 44.45


### Distil

In [60]:
evaluator = pytrec_eval.RelevanceEvaluator(relevant_docs, {'ndcg_cut.10', 'recall_100', 'map_cut.1000'})
scores = evaluator.evaluate(run)

print("Queries:", len(relevant_qid))
print("NDCG@10: {:.2f}".format(np.mean([ele["ndcg_cut_10"] for ele in scores.values()])*100))
print("Recall@100: {:.2f}".format(np.mean([ele["recall_100"] for ele in scores.values()])*100))
print("MAP@1000: {:.2f}".format(np.mean([ele["map_cut_1000"] for ele in scores.values()])*100))

Queries: 26
NDCG@10: 11.13
Recall@100: 15.93
MAP@1000: 12.85


### Tiny

In [53]:
evaluator = pytrec_eval.RelevanceEvaluator(relevant_docs, {'ndcg_cut.10', 'recall_100', 'map_cut.1000'})
scores = evaluator.evaluate(run)

print("Queries:", len(relevant_qid))
print("NDCG@10: {:.2f}".format(np.mean([ele["ndcg_cut_10"] for ele in scores.values()])*100))
print("Recall@100: {:.2f}".format(np.mean([ele["recall_100"] for ele in scores.values()])*100))
print("MAP@1000: {:.2f}".format(np.mean([ele["map_cut_1000"] for ele in scores.values()])*100))

Queries: 26
NDCG@10: 70.20
Recall@100: 54.64
MAP@1000: 49.39
