Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add mteb evaluation #8538

Merged
merged 5 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 88 additions & 4 deletions pipelines/examples/contrastive_training/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

## 安装

推荐安装gpu版本的[PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html),以cuda11.7的paddle为例,安装命令如下:
推荐安装gpu版本的[PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/conda/linux-conda.html),以cuda11.7的paddle为例,安装命令如下:

```
python -m pip install paddlepaddle-gpu==2.6.0.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
conda install nccl -c conda-forge
conda install paddlepaddle-gpu==2.6.1 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/Paddle/ -c conda-forge
```
安装其他依赖:
```
Expand Down Expand Up @@ -98,15 +99,98 @@ python evaluation/benchmarks.py --model_type bert \
--passage_model checkpoints/checkpoint-1500 \
--query_max_length 64 \
--passage_max_length 512 \
--evaluate_all
```
- `model_type`: 模型的类似,可选bert或roberta等等
- `query_model`: query向量模型的路径
- `passage_model`: passage向量模型的路径
- `query_max_length`: query的最大长度
- `passage_max_length`: passage的最大长度
- `evaluate_all`: 是否评估所有的checkpoint,默认为False,即只评估指定的checkpoint
- `checkpoint_dir`: 与`evaluate_all`一起使用


## MTEB评估
[MTEB](https://github.com/embeddings-benchmark/mteb)
是一个大规模文本嵌入评测基准,包含了丰富的向量检索评估任务和数据集。
本仓库主要面向其中的中英文检索任务(Retrieval),并以SciFact数据集作为主要示例。

评估RepLLaMA向量检索模型([repllama-v1-7b-lora-passage](https://huggingface.co/castorini/repllama-v1-7b-lora-passage)):
```
export CUDA_VISIBLE_DEVICES=0
python evaluation/mteb/eval_mteb.py \
--base_model_name_or_path castorini/repllama-v1-7b-lora-passage \
--output_folder en_results/repllama-v1-7b-lora-passage \
--task_name SciFact \
--task_split test \
--query_instruction 'query: ' \
--document_instruction 'passage: ' \
--pooling_method last \
--max_seq_length 512 \
--eval_batch_size 2 \
--pad_token unk_token \
--padding_side right \
--add_bos_token 0 \
--add_eos_token 1
```
结果文件保存在`en_results/repllama-v1-7b-lora-passage/SciFact/last/no_revision_available/SciFact.json`,包含以下类似的评估结果:
```
'ndcg_at_1': 0.63,
'ndcg_at_3': 0.71785,
'ndcg_at_5': 0.73735,
'ndcg_at_10': 0.75708,
'ndcg_at_20': 0.7664,
'ndcg_at_100': 0.77394,
'ndcg_at_1000': 0.7794
```

评估BGE向量检索模型([bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5)):
```
export CUDA_VISIBLE_DEVICES=0
python evaluation/mteb/eval_mteb.py \
--base_model_name_or_path BAAI/bge-large-en-v1.5 \
--output_folder en_results/bge-large-en-v1.5 \
--task_name SciFact \
--task_split test \
--document_instruction 'Represent this sentence for searching relevant passages: ' \
--pooling_method mean \
--max_seq_length 512 \
--eval_batch_size 32 \
--pad_token pad_token \
--padding_side right \
--add_bos_token 0 \
--add_eos_token 0
```
结果文件保存在`en_results/bge-large-en-v1.5/SciFact/mean/no_revision_available/SciFact.json`,包含以下类似的评估结果:
```
'ndcg_at_1': 0.64667,
'ndcg_at_3': 0.70359,
'ndcg_at_5': 0.7265,
'ndcg_at_10': 0.75675,
'ndcg_at_20': 0.76743,
'ndcg_at_100': 0.77511,
'ndcg_at_1000': 0.77939
```

可支持配置的参数:
- `base_model_name_or_path`: 模型名称或路径
- `output_folder`: 结果文件存储路径
- `task_name`:任务(数据集)名称,如SciFact
- `task_split`:测试查询集合,如test或dev
- `query_instruction`:查询前添加的提示文本,如'query: '或None
- `document_instruction`:文档前添加的提示文本,如'passage: '或None
- `pooling_method`:获取表示的方式,last表示取最后token,mean表示取平均,cls表示取`[CLS]`token
- `max_seq_length`: 最大序列长度
- `eval_batch_size`: 模型预测的批次大小(单个GPU)
- `pad_token`:设置padding的token,可取unk_token、eos_token或pad_token
- `padding_side`:设置padding的位置,可取left或right
- `add_bos_token`:是否添加起始符,0表示不添加,1表示添加
- `add_eos_token`:是否添加结束符,0表示不添加,1表示添加


## Reference

[1] Aditya Kusupati, Gantavya Bhatt, Aniket Rege, Matthew Wallingford, Aditya Sinha, Vivek Ramanujan, William Howard-Snyder, Kaifeng Chen, Sham M. Kakade, Prateek Jain, Ali Farhadi: Matryoshka Representation Learning. NeurIPS 2022
[1] Aditya Kusupati, Gantavya Bhatt, Aniket Rege, Matthew Wallingford, Aditya Sinha, Vivek Ramanujan, William Howard-Snyder, Kaifeng Chen, Sham M. Kakade, Prateek Jain, Ali Farhadi: Matryoshka Representation Learning. NeurIPS 2022.

[2] Xueguang Ma, Liang Wang, Nan Yang, Furu Wei, Jimmy Lin: Fine-Tuning LLaMA for Multi-Stage Text Retrieval. arXiv 2023.

[3] Shitao Xiao, Zheng Liu, Peitian Zhang, Niklas Muennighof: C-Pack: Packaged Resources To Advance General Chinese Embedding. SIGIR 2024.
107 changes: 107 additions & 0 deletions pipelines/examples/contrastive_training/evaluation/mteb/eval_mteb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging

from mteb import MTEB
from mteb_models import EncodeModel

from paddlenlp.transformers import AutoModel, AutoTokenizer


def get_model(peft_model_name, base_model_name):
if peft_model_name is not None:
raise NotImplementedError("PEFT model is not supported yet")
else:
base_model = AutoModel.from_pretrained(base_model_name)
return base_model


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--base_model_name_or_path", default="bge-large-en-v1.5", type=str)
parser.add_argument("--peft_model_name_or_path", default=None, type=str)
parser.add_argument("--output_folder", default="tmp", type=str)

parser.add_argument("--task_name", default="SciFact", type=str)
parser.add_argument(
"--task_split",
default="test",
help='Note that some datasets do not have "test", they only have "dev"',
type=str,
)

parser.add_argument("--query_instruction", default=None, help="add prefix instruction before query", type=str)
parser.add_argument(
"--document_instruction", default=None, help="add prefix instruction before document", type=str
)

parser.add_argument("--pooling_method", default="last", help="choose in [mean, last, cls]", type=str)
parser.add_argument("--max_seq_length", default=512, type=int)
parser.add_argument("--eval_batch_size", default=1, type=int)

parser.add_argument("--pad_token", default="unk_token", help="unk_token, eos_token or pad_token", type=str)
parser.add_argument("--padding_side", default="left", help="right or left", type=str)
parser.add_argument("--add_bos_token", default=0, help="1 means add token", type=int)
parser.add_argument("--add_eos_token", default=1, help="1 means add token", type=int)

return parser.parse_args()


if __name__ == "__main__":
args = get_args()

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
logger.info("Args: {}".format(args))

model = get_model(args.peft_model_name_or_path, args.base_model_name_or_path)

tokenizer = AutoTokenizer.from_pretrained(args.base_model_name_or_path)
assert hasattr(tokenizer, args.pad_token), f"Tokenizer does not have {args.pad_token} token"
token_dict = {"unk_token": tokenizer.unk_token, "eos_token": tokenizer.eos_token, "pad_token": tokenizer.pad_token}
tokenizer.pad_token = token_dict[args.pad_token]

assert args.padding_side in [
"right",
"left",
], f"padding_side should be either 'right' or 'left', but got {args.padding_side}"
assert not (
args.padding_side == "left" and args.pooling_method == "cls"
), "Padding 'left' is not supported for pooling method 'cls'"
tokenizer.padding_side = args.padding_side

assert args.add_bos_token in [0, 1], f"add_bos_token should be either 0 or 1, but got {args.add_bos_token}"
assert args.add_eos_token in [0, 1], f"add_eos_token should be either 0 or 1, but got {args.add_eos_token}"
tokenizer.add_bos_token = bool(args.add_bos_token)
tokenizer.add_eos_token = bool(args.add_eos_token)

encode_model = EncodeModel(
model=model,
tokenizer=tokenizer,
pooling_method=args.pooling_method,
query_instruction=args.query_instruction,
document_instruction=args.document_instruction,
eval_batch_size=args.eval_batch_size,
max_seq_length=args.max_seq_length,
)

logger.info("Ready to eval")
evaluation = MTEB(tasks=[args.task_name])
evaluation.run(
encode_model,
output_folder=f"{args.output_folder}/{args.task_name}/{args.pooling_method}",
eval_splits=[args.task_split],
)
127 changes: 127 additions & 0 deletions pipelines/examples/contrastive_training/evaluation/mteb/mteb_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Union

import numpy as np
import paddle
from tqdm import tqdm


class EncodeModel:
def __init__(
self,
model,
tokenizer,
pooling_method: str = "last",
query_instruction: str = None,
document_instruction: str = None,
eval_batch_size: int = 64,
max_seq_length: int = 512,
):
self.model = model
self.tokenizer = tokenizer
self.pooling_method = pooling_method
self.query_instruction = query_instruction
self.document_instruction = document_instruction
self.eval_batch_size = eval_batch_size
self.max_seq_length = max_seq_length

if paddle.device.is_compiled_with_cuda():
self.device = paddle.device.set_device("gpu")
else:
self.device = paddle.device.set_device("cpu")
self.model = self.model.to(self.device)

num_gpus = paddle.device.cuda.device_count()
if num_gpus > 1:
raise NotImplementedError("Multi-GPU is not supported yet.")

def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
"""
This function will be used to encode queries for retrieval task
if there is a instruction for queries, we will add it to the query text
"""
if self.query_instruction is not None:
input_texts = [f"{self.query_instruction}{query}" for query in queries]
else:
input_texts = queries
return self.encode(input_texts)

def encode_corpus(self, corpus: List[Union[Dict[str, str], str]], **kwargs) -> np.ndarray:
"""
This function will be used to encode corpus for retrieval task
if there is a instruction for docs, we will add it to the doc text
"""
if isinstance(corpus[0], dict):
if self.document_instruction is not None:
input_texts = [
"{}{} {}".format(self.document_instruction, doc.get("title", ""), doc["text"]).strip()
for doc in corpus
]
else:
input_texts = ["{} {}".format(doc.get("title", ""), doc["text"]).strip() for doc in corpus]
else:
if self.document_instruction is not None:
input_texts = [f"{self.document_instruction}{doc}" for doc in corpus]
else:
input_texts = corpus
return self.encode(input_texts)

@paddle.no_grad()
def encode(self, sentences: List[str], **kwargs) -> np.ndarray:
self.model.eval()
all_embeddings = []
for start_index in tqdm(range(0, len(sentences), self.eval_batch_size), desc="Batches"):
sentences_batch = sentences[start_index : start_index + self.eval_batch_size]

inputs = self.tokenizer(
sentences_batch,
padding=True,
truncation=True,
return_tensors="pd",
max_length=self.max_seq_length,
return_attention_mask=True,
)
outputs = self.model(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
return_dict=True,
output_hidden_states=True,
)
last_hidden_state = outputs.hidden_states[-1]

if self.pooling_method == "last":
if self.tokenizer.padding_side == "right":
sequence_lengths = inputs.attention_mask.sum(axis=1)
last_token_indices = sequence_lengths - 1
embeddings = last_hidden_state[paddle.arange(last_hidden_state.shape[0]), last_token_indices]
elif self.tokenizer.padding_side == "left":
embeddings = last_hidden_state[:, -1]
else:
raise NotImplementedError(f"Padding side {self.tokenizer.padding_side} not supported.")
elif self.pooling_method == "cls":
embeddings = last_hidden_state[:, 1]
elif self.pooling_method == "mean":
s = paddle.sum(last_hidden_state * inputs.attention_mask.unsqueeze(-1), axis=1)
d = inputs.attention_mask.sum(axis=1, keepdim=True)
embeddings = s / d
else:
raise NotImplementedError(f"Pooling method {self.pooling_method} not supported.")

embeddings = paddle.nn.functional.normalize(embeddings, p=2, axis=-1)

all_embeddings.append(embeddings.cpu().numpy().astype("float32"))

return np.concatenate(all_embeddings, axis=0)
3 changes: 2 additions & 1 deletion pipelines/examples/contrastive_training/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
paddlenlp>2.6.1
datasets
torch==2.0.1
mteb[beir]
mteb
beir
typer==0.9.0
Loading