In [None]:
!pip install sentence-transformers

In [None]:
!pip install pytrec_eval

In [None]:
"""
This example shows how to train a Cross-Encoder for the MS Marco dataset (https://github.com/microsoft/MSMARCO-Passage-Ranking).

The query and the passage are passed simultaneously to a Transformer network. The network then returns
a score between 0 and 1 indicating how relevant the passage is for a given query.

The resulting Cross-Encoder can then be used for passage re-ranking: You retrieve, for example, 100 passages
for a given query (e.g. with ElasticSearch) and pass each query + retrieved passage pair to the CrossEncoder
for scoring. You then sort the results according to the output of the CrossEncoder.

This gives a significant boost compared to out-of-the-box ElasticSearch / BM25 ranking.

NOTE(review): the cells visible in this notebook load and evaluate an already fine-tuned model rather than
train one - confirm that this description is still accurate.
"""
from torch.utils.data import DataLoader
from sentence_transformers import LoggingHandler, util
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
from sentence_transformers import InputExample
from datetime import datetime
import gzip
import os
import tarfile
import tqdm
import logging
from collections import defaultdict
import numpy as np
import sys
import pytrec_eval
# Attach a timestamped handler to the root logger (only takes effect when the
# root logger has no handler yet), then lower the root threshold to DEBUG.
# The root level also applies to third-party loggers that propagate to it.
logging.basicConfig(
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

In [None]:
# Download the MS MARCO queries archive and unpack it into the working directory.
# NOTE(review): no cell visible in this notebook reads the extracted files (the test
# queries are loaded from msmarco-test2019-queries.tsv.gz) - confirm this download is needed.
!wget https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz
!tar -xvzf  queries.tar.gz

读取三个文件

In [None]:
"""
This file evaluates CrossEncoder on the TREC 2019 Deep Learning (DL) Track: https://arxiv.org/abs/2003.07820

TREC 2019 DL is based on the corpus of MS Marco. MS Marco provides a sparse annotation, i.e., usually only a single
passage is marked as relevant for a given query. Many other highly relevant passages are not annotated and hence are treated
as an error if a model ranks those high.

TREC DL instead annotated up to 200 passages per query for their relevance to a given query. It is better suited to estimate
the model performance for the task of reranking in Information Retrieval.

Run:
python eval_cross-encoder-trec-dl.py cross-encoder-model-name

"""


# Local cache directory for the TREC DL 2019 files; each file is downloaded only once.
data_folder = 'trec2019-data'
os.makedirs(data_folder, exist_ok=True)

# Read the test queries: one tab-separated "<qid>\t<query text>" pair per line.
queries = {}
queries_filepath = os.path.join(data_folder, 'msmarco-test2019-queries.tsv.gz')
if not os.path.exists(queries_filepath):
    logging.info("Download "+os.path.basename(queries_filepath))
    util.http_get('https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz', queries_filepath)

with gzip.open(queries_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        queries[qid] = query

# Read which passages are relevant: relevant_docs[qid][pid] = relevance score.
# Only judgments with a score above zero are kept.
relevant_docs = defaultdict(lambda: defaultdict(int))
qrels_filepath = os.path.join(data_folder, '2019qrels-pass.txt')

if not os.path.exists(qrels_filepath):
    logging.info("Download "+os.path.basename(qrels_filepath))
    util.http_get('https://trec.nist.gov/data/deep/2019qrels-pass.txt', qrels_filepath)

with open(qrels_filepath) as fIn:
    for line in fIn:
        # The second column is not used. It is not bound to `_` because IPython
        # uses `_` for the last cell output.
        qid, _unused, pid, score = line.strip().split()
        score = int(score)
        if score > 0:
            relevant_docs[qid][pid] = score

# Only use queries that have at least one relevant passage.
# `.get` is used instead of `relevant_docs[qid]` because indexing a defaultdict
# inserts an empty entry for every query that has no judgment, which would make
# later loops over `relevant_docs` visit queries without any relevant passage.
relevant_qid = [qid for qid in queries if relevant_docs.get(qid)]

# Read the top 1000 passages that are supposed to be re-ranked:
# passage_cand[qid] is a list of [pid, passage text] pairs in file order.
passage_filepath = os.path.join(data_folder, 'msmarco-passagetest2019-top1000.tsv.gz')

if not os.path.exists(passage_filepath):
    logging.info("Download "+os.path.basename(passage_filepath))
    util.http_get('https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-passagetest2019-top1000.tsv.gz', passage_filepath)

passage_cand = {}
with gzip.open(passage_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        # The query text repeated on every line is not needed; `queries` already has it.
        qid, pid, _query_text, passage = line.strip().split("\t")
        passage_cand.setdefault(qid, []).append([pid, passage])

logging.info("Queries: {}".format(len(queries)))

# Print compact summaries instead of the full dictionaries: dumping every
# candidate passage floods the cell output and bloats the saved notebook.
print("queries:", len(queries), "entries; sample:", list(queries.items())[:3])
print("relevant_docs:", len(relevant_docs), "queries with at least one relevant passage")
print("relevant_qid:", len(relevant_qid), "entries; sample:", relevant_qid[:5])
print("passage_cand:", len(passage_cand), "queries;",
      sum(len(cands) for cands in passage_cand.values()), "candidate passages in total")

In [None]:
# Preview: for every judged query that also has re-ranking candidates, print its
# three highest-scored relevant passages. A passage that is missing from the
# candidate list is skipped without a message.
for qid, judged in relevant_docs.items():
    if qid not in passage_cand:
        continue

    best_three = sorted(judged.items(), key=lambda item: item[1], reverse=True)[:3]
    print(f"Top 3 relevant passages for query ID {qid}:")

    for pid, score in best_three:
        # First candidate entry ([pid, text]) whose id matches, or None.
        p = next((cand for cand in passage_cand[qid] if cand[0] == pid), None)
        if p is not None:
            print(f"PID: {pid}, Score: {score}, Passage: {p[1]}")

In [None]:
# Same preview as the cell above, but it also reports passages that are missing
# from the candidate list and queries for which nothing could be printed.
for qid, judged in relevant_docs.items():
    if qid not in passage_cand:
        continue

    best_three = sorted(judged.items(), key=lambda item: item[1], reverse=True)[:3]
    print(f"Top 3 relevant passages for query ID {qid}:")

    printed_any = False
    for pid, score in best_three:
        # First candidate entry ([pid, text]) whose id matches, or None.
        p = next((cand for cand in passage_cand[qid] if cand[0] == pid), None)
        if p is None:
            print(f"Passage ID {pid} not found in the top 1000 passages list.")
        else:
            print(f"PID: {pid}, Score: {score}, Passage: {p[1]}")
            printed_any = True

    if not printed_any:
        print(f"No relevant passages found or printed for query ID {qid}.")

In [None]:
# top_docs[qid] -> texts of the three highest-scored relevant passages of that
# query. A query is only stored when all three texts were found among its
# re-ranking candidates.
# NOTE(review): the passages are selected from the qrels (relevance judgments),
# and the same qrels file is later used to score the runs in the evaluation
# cells - confirm that feeding judged passages into the query expansion is intended.
top_docs = {}

for qid, judged in relevant_docs.items():
    if qid not in passage_cand:
        continue

    top_three_docs = sorted(judged.items(), key=lambda item: item[1], reverse=True)[:3]
    top_docs_text = []

    for pid, _score in top_three_docs:
        # First candidate entry ([pid, text]) whose id matches, or None.
        match = next((cand for cand in passage_cand[qid] if cand[0] == pid), None)
        if match is None:
            print(f"Passage ID {pid} not found in the top 1000 passages list.")
        else:
            top_docs_text.append(match[1])

    if len(top_docs_text) != 3:
        print(f"Less than 3 relevant passages found for query ID {qid}, only found {len(top_docs_text)}.")
        continue

    top_docs[qid] = top_docs_text

通过relevant_docs中的每个query的文档相关性评分，得到每个query相关性最高的前三个文档，然后passage_cand中找到文档的相关内容

In [None]:
!pip install bitsandbytes==0.43.1
!pip install transformers==4.40.2
!pip install peft==0.11.1
!pip install accelerate==0.30.1

In [None]:
import os
import json
import tqdm
import sys

使用生成模型进行查询扩展：

In [None]:
from transformers import AutoTokenizer, LlamaForCausalLM, AutoModelForCausalLM

# Checkpoint identifier passed to both from_pretrained calls below.
model_name = "HuggingFaceH4/zephyr-7b-beta"
# NOTE(review): `truncation`, `padding` and `maximum_length` do not look like
# documented from_pretrained options (truncation/padding are normally passed when
# the tokenizer is called) - confirm whether they have any effect here.
tokenizer = AutoTokenizer.from_pretrained(model_name, truncation=True, padding=True, padding_side="left", maximum_length = 2048, model_max_length = 2048)
# Requests 4-bit loading and automatic device placement; presumably relies on the
# bitsandbytes/accelerate packages installed above - TODO confirm.
# Note: the name `model` is rebound to a CrossEncoder in the prediction cells further down.
model = AutoModelForCausalLM.from_pretrained(model_name, load_in_4bit = True, device_map = 'auto')
# Use the end-of-sequence token as the padding token, both in the tokenizer and
# in the model's generation config.
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = model.generation_config.eos_token_id

In [None]:
import gzip
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, BertForSequenceClassification
from collections import defaultdict
import pytrec_eval
import numpy as np

# Text-generation pipeline built from the `model` and `tokenizer` objects created
# in the previous cell; generate_expanded_queries below calls it once per query.
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)

def generate_expanded_queries(queries, top_docs, text_generator=None, max_new_tokens=64):
    """Expand each query with keywords generated from its three feedback passages.

    For every query id in ``queries`` that has exactly three passages in
    ``top_docs``, a prompt containing the passages and the query is sent to the
    text generator, and the generated keywords are appended to the original
    query text. Queries without exactly three passages are reported and skipped.

    Args:
        queries: Mapping ``query_id -> query text``.
        top_docs: Mapping ``query_id -> list of passage texts``.
        text_generator: Optional callable with the call signature of a
            text-generation pipeline (``prompt, **kwargs`` ->
            ``[{"generated_text": ...}]``). Defaults to the module-level
            ``generator``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        dict mapping ``query_id`` to ``"<query text> <generated keywords>"``.
    """
    # Resolved lazily so the module-level `generator` is only required when no
    # callable is injected.
    generate = text_generator if text_generator is not None else generator

    expanded_queries = {}
    for query_id in queries:
        if query_id in top_docs and len(top_docs[query_id]) == 3:
            PRF_doc_1, PRF_doc_2, PRF_doc_3 = top_docs[query_id]
            query_text = queries[query_id]
            input_text = (f"Write a list of keywords for the given query based on the context:\n"
                          f"Context: {PRF_doc_1}\n{PRF_doc_2}\n{PRF_doc_3}\nQuery: {query_text}\nKeywords:")

            # return_full_text=False: by default the pipeline returns the prompt
            # followed by the continuation, which would copy the instruction and
            # all three passages into the "expanded query".
            # max_new_tokens: `max_length` also counts the prompt tokens, so a
            # long prompt could leave little or no room for the keywords.
            outputs = generate(
                input_text,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                return_full_text=False,
            )
            generated_text = outputs[0]['generated_text'].strip()

            # Fall back to the plain query when nothing was generated, so no
            # trailing space is appended.
            concatenated_query = f"{query_text} {generated_text}" if generated_text else query_text
            expanded_queries[query_id] = concatenated_query
            print("Expanded query:", concatenated_query)
        else:
            print(f"No enough docs for query {query_id}. Needed 3, got {len(top_docs.get(query_id, []))}")

    return expanded_queries

# Run the expansion for all test queries; only queries with exactly three
# passages in `top_docs` end up in the result.
expanded_queries = generate_expanded_queries(queries, top_docs)

In [None]:
# Display the expanded queries produced by the previous cell.
expanded_queries

In [None]:
import pickle
from transformers import AutoTokenizer
def truncate_queries(queries, tokenizer, max_length=512):
    """Clip every query to at most ``max_length`` tokens of ``tokenizer``.

    Each query is encoded with truncation enabled and decoded back to text with
    special tokens removed, so the returned text is whatever the tokenizer
    reproduces from the kept tokens.

    Args:
        queries: Mapping ``query_id -> query text``.
        tokenizer: Object providing ``encode(text, truncation=..., max_length=...)``
            and ``decode(tokens, skip_special_tokens=...)``.
        max_length: Maximum number of tokens passed to ``encode``.

    Returns:
        dict mapping ``str(query_id)`` to the round-tripped query text.
    """
    def _round_trip(text):
        token_ids = tokenizer.encode(text, truncation=True, max_length=max_length)
        return tokenizer.decode(token_ids, skip_special_tokens=True)

    return {str(query_id): _round_trip(text) for query_id, text in queries.items()}

# Tokenizer that defines the token budget used for clipping the expanded queries.
sim_model_name = 'bert-base-uncased'
sim_tokenizer = AutoTokenizer.from_pretrained(sim_model_name)

# NOTE(review): the limit of 512 tokens applies to the query alone, while the
# cross-encoder below is created with max_length=512 - presumably for the
# query/passage pair; confirm the budget leaves room for the passage.
truncated_expanded_queries = truncate_queries(
    queries=expanded_queries,
    tokenizer=sim_tokenizer,
    max_length=512,
)

# Persist the result; the evaluation data-loading cell reads this file back.
with open('truncated_expanded_queries.pkl', 'wb') as pickle_file:
    pickle.dump(truncated_expanded_queries, pickle_file)

print("Truncated expanded queries have been saved to truncated_expanded_queries.pkl")



## Evaluation preparation

### Initialize hyperparameters (e.g., batch size, etc)

In [None]:
from google.colab import drive
# Mount Google Drive at /content/gdrive.
drive.mount('/content/gdrive')
# Absolute path below the mount point. The previous relative "./gdrive/..." form
# only pointed at the mounted drive while the working directory was /content,
# and it disagreed with the absolute `model_save_path` used further down.
base_path = "/content/gdrive/MyDrive/cross-encoder-reranker-ir-course-2023/"

In [None]:
# Create the project folder if it does not exist yet (`$base_path` is the Python
# variable set in the previous cell, interpolated by IPython).
!mkdir -p $base_path

## Evaluate the model


### Load the fine-tuned model that you trained using the previous notebook. You need to set the path of your own fine-tuned model here.

In [None]:
# Location of the fine-tuned cross-encoder that the prediction cells load.
# NOTE(review): the Mini, Distil and Tiny prediction cells all read this single
# variable, so presumably it is edited before running each of them - confirm.
model_save_path = "/content/gdrive/MyDrive/cross-encoder-reranker-ir-course-2023/finetuned_models/cross-encoder-distilbert-distilroberta-base-2024-05-12_07-39-41" #@param {type:"string"}

### Load data (For evaluation on TREC DL'19)

In [None]:
# Download the MS MARCO queries archive and unpack it into the working directory.
# NOTE(review): the same archive is already fetched near the top of the notebook,
# and no visible cell reads the extracted files - confirm this step is needed.
!wget https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz
!tar -xvzf  queries.tar.gz

In [None]:
import pickle
"""
This file evaluates CrossEncoder on the TREC 2019 Deep Learning (DL) Track: https://arxiv.org/abs/2003.07820

TREC 2019 DL is based on the corpus of MS Marco. MS Marco provides a sparse annotation, i.e., usually only a single
passage is marked as relevant for a given query. Many other highly relevant passages are not annotated and hence are treated
as an error if a model ranks those high.

TREC DL instead annotated up to 200 passages per query for their relevance to a given query. It is better suited to estimate
the model performance for the task of reranking in Information Retrieval.

Run:
python eval_cross-encoder-trec-dl.py cross-encoder-model-name

"""


# Local cache directory for the TREC DL 2019 files.
data_folder = 'trec2019-data'
os.makedirs(data_folder, exist_ok=True)

# The evaluation uses the expanded queries written by the truncation cell above
# instead of the original msmarco-test2019 queries.
# NOTE(review): this rebinds `queries`, so re-running the expansion cell after
# this one would expand already expanded queries - confirm the intended run order.
with open('truncated_expanded_queries.pkl', 'rb') as pickled_queries:
    queries = pickle.load(pickled_queries)

# relevant_docs[qid][pid] = relevance score; only scores above zero are stored.
relevant_docs = defaultdict(lambda: defaultdict(int))
qrels_filepath = os.path.join(data_folder, '2019qrels-pass.txt')

if not os.path.exists(qrels_filepath):
    logging.info("Download "+os.path.basename(qrels_filepath))
    util.http_get('https://trec.nist.gov/data/deep/2019qrels-pass.txt', qrels_filepath)

with open(qrels_filepath) as qrels_file:
    for qrels_line in qrels_file:
        qid, _, pid, score_text = qrels_line.strip().split()
        grade = int(score_text)
        if grade > 0:
            relevant_docs[qid][pid] = grade

# Keep only the queries that have at least one relevant passage.
relevant_qid = [qid for qid in queries if len(relevant_docs[qid]) > 0]

# Candidate passages to re-rank: passage_cand[qid] is a list of [pid, text] pairs.
passage_filepath = os.path.join(data_folder, 'msmarco-passagetest2019-top1000.tsv.gz')

if not os.path.exists(passage_filepath):
    logging.info("Download "+os.path.basename(passage_filepath))
    util.http_get('https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-passagetest2019-top1000.tsv.gz', passage_filepath)

passage_cand = {}
with gzip.open(passage_filepath, 'rt', encoding='utf8') as candidates_file:
    for candidate_line in candidates_file:
        qid, pid, query, passage = candidate_line.strip().split("\t")
        passage_cand.setdefault(qid, []).append([pid, passage])

logging.info("Queries: {}".format(len(queries)))


## Prediction

### Mini

In [None]:
queries_result_list = []
# run[qid][pid] = cross-encoder score, the layout consumed by the evaluation cells.
run = {}
# Loads whatever checkpoint `model_save_path` currently points to.
model = CrossEncoder(model_save_path, max_length=512)

for qid in tqdm.tqdm(relevant_qid):
    query_text = queries[qid]
    candidates = passage_cand[qid]

    # One [query, passage] pair per candidate, in candidate order.
    pairs = [[query_text, cand[1]] for cand in candidates]

    if model.config.num_labels > 1:
        # Multi-label head: apply softmax and keep column 1 of the output.
        pair_scores = model.predict(pairs, apply_softmax=True)[:, 1].tolist()
    else:
        pair_scores = model.predict(pairs).tolist()

    # Map each passage id to its score (a repeated id keeps its last score).
    run[qid] = {cand[0]: float(pair_scores[position]) for position, cand in enumerate(candidates)}

### Distil

In [None]:
queries_result_list = []
# run[qid][pid] = cross-encoder score, the layout consumed by the evaluation cells.
run = {}
# Loads whatever checkpoint `model_save_path` currently points to.
model = CrossEncoder(model_save_path, max_length=512)

for qid in tqdm.tqdm(relevant_qid):
    query_text = queries[qid]
    candidates = passage_cand[qid]

    # One [query, passage] pair per candidate, in candidate order.
    pairs = [[query_text, cand[1]] for cand in candidates]

    if model.config.num_labels > 1:
        # Multi-label head: apply softmax and keep column 1 of the output.
        pair_scores = model.predict(pairs, apply_softmax=True)[:, 1].tolist()
    else:
        pair_scores = model.predict(pairs).tolist()

    # Map each passage id to its score (a repeated id keeps its last score).
    run[qid] = {cand[0]: float(pair_scores[position]) for position, cand in enumerate(candidates)}

### Tiny

In [None]:
queries_result_list = []
# run[qid][pid] = cross-encoder score, the layout consumed by the evaluation cells.
run = {}
# Loads whatever checkpoint `model_save_path` currently points to.
model = CrossEncoder(model_save_path, max_length=512)

for qid in tqdm.tqdm(relevant_qid):
    query_text = queries[qid]
    candidates = passage_cand[qid]

    # One [query, passage] pair per candidate, in candidate order.
    pairs = [[query_text, cand[1]] for cand in candidates]

    if model.config.num_labels > 1:
        # Multi-label head: apply softmax and keep column 1 of the output.
        pair_scores = model.predict(pairs, apply_softmax=True)[:, 1].tolist()
    else:
        pair_scores = model.predict(pairs).tolist()

    # Map each passage id to its score (a repeated id keeps its last score).
    run[qid] = {cand[0]: float(pair_scores[position]) for position, cand in enumerate(candidates)}

## Evaluation

### Mini

In [None]:
# Score the current `run` against the relevance judgments.
# NOTE(review): `run` holds the output of whichever prediction cell was executed
# last - confirm it belongs to the model named in this section's heading.
evaluator = pytrec_eval.RelevanceEvaluator(relevant_docs, {'ndcg_cut.10', 'recall_100', 'map_cut.1000'})
scores = evaluator.evaluate(run)

print("Queries:", len(relevant_qid))
# Report each measure as the mean over all evaluated queries, in percent.
for template, measure in (
    ("NDCG@10: {:.2f}", "ndcg_cut_10"),
    ("Recall@100: {:.2f}", "recall_100"),
    ("MAP@1000: {:.2f}", "map_cut_1000"),
):
    per_query_values = [ele[measure] for ele in scores.values()]
    print(template.format(np.mean(per_query_values)*100))

### Distil

In [None]:
# Score the current `run` against the relevance judgments.
# NOTE(review): `run` holds the output of whichever prediction cell was executed
# last - confirm it belongs to the model named in this section's heading.
evaluator = pytrec_eval.RelevanceEvaluator(relevant_docs, {'ndcg_cut.10', 'recall_100', 'map_cut.1000'})
scores = evaluator.evaluate(run)

print("Queries:", len(relevant_qid))
# Report each measure as the mean over all evaluated queries, in percent.
for template, measure in (
    ("NDCG@10: {:.2f}", "ndcg_cut_10"),
    ("Recall@100: {:.2f}", "recall_100"),
    ("MAP@1000: {:.2f}", "map_cut_1000"),
):
    per_query_values = [ele[measure] for ele in scores.values()]
    print(template.format(np.mean(per_query_values)*100))

### Tiny

In [None]:
# Score the current `run` against the relevance judgments.
# NOTE(review): `run` holds the output of whichever prediction cell was executed
# last - confirm it belongs to the model named in this section's heading.
evaluator = pytrec_eval.RelevanceEvaluator(relevant_docs, {'ndcg_cut.10', 'recall_100', 'map_cut.1000'})
scores = evaluator.evaluate(run)

print("Queries:", len(relevant_qid))
# Report each measure as the mean over all evaluated queries, in percent.
for template, measure in (
    ("NDCG@10: {:.2f}", "ndcg_cut_10"),
    ("Recall@100: {:.2f}", "recall_100"),
    ("MAP@1000: {:.2f}", "map_cut_1000"),
):
    per_query_values = [ele[measure] for ele in scores.values()]
    print(template.format(np.mean(per_query_values)*100))