feat: peft training for llm embed finetuning
LongxingTan committed Jun 22, 2024
1 parent cc5f5fa commit fb6f376
Showing 18 changed files with 274 additions and 160 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -29,9 +29,9 @@

**Open-retrievals** simplifies text embeddings, retrieval, ranking, and RAG using PyTorch and Transformers. This user-friendly framework is designed for information retrieval and LLM generation.
- Embeddings, retrieval and rerank all-in-one: `AutoModelForEmbedding`
- Contrastive learning/LLM enhanced embeddings, with point-wise, pairwise and listwise training
- Cross-encoder, ColBERT and LLM enhanced Reranking
- Fast RAG demo integrated with Langchain and LlamaIndex
- Contrastive learning/LLM enhanced embeddings, with point-wise, pairwise and listwise fine-tuning
- Cross-encoder, ColBERT and LLM reranker
- Fast RAG, easily integrated with Langchain and LlamaIndex

![structure](./docs/source/_static/structure.png)

4 changes: 2 additions & 2 deletions README_zh-CN.md
@@ -29,8 +29,8 @@

**Open-Retrievals** helps developers easily apply text embeddings to information retrieval and LLM applications, quickly building retrieval, ranking, and RAG pipelines.
- `AutoModelForEmbedding` unifies embedding, retrieval, and reranking
- Fine-tune embedding and rerank models with contrastive learning and point-wise, pairwise, and listwise objectives
- Build a RAG demo quickly, either customized or integrated with Langchain and LlamaIndex
- Multiple fine-tuning methods for embedding and rerank models: contrastive learning, LLM-based, point-wise, pairwise, and listwise
- Customized RAG framework, with fine-tuned models also easy to use in Langchain and LlamaIndex

![structure](./docs/source/_static/structure.png)

33 changes: 33 additions & 0 deletions docs/source/rerank.rst
@@ -63,3 +63,36 @@ Fine tuning ColBERT

Fine tuning LLM ranker
----------------------------

- Point-wise style prompt (query generation):

"Passage: {text}\nPlease write a question based on this passage."

- Point-wise style prompt (relevance generation):

"Passage: {text}\nQuery: {query}\nDoes the passage answer the query? Answer 'Yes' or 'No'"

- Pair-wise style prompt:

"Given a query "{query}", which of the following two passages is more relevant to the query?

Passage A: "{doc1}"

Passage B: "{doc2}"

Output Passage A or Passage B:"

- List-wise style prompt:

"I will provide you with {num} passages, each indicated by number identifier []. \nRank the passages based on their relevance to query: {query}."

- Set-wise style prompt:

"Given a query "{query}", which of the following passages is the most relevant one to the query?\n\n{passages}\n\nOutput only the passage label of the most relevant passage:"
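
A minimal sketch of the point-wise relevance-generation prompt in use, scoring a passage by the next-token probability of "Yes" versus "No"; the backbone model name and token handling below are illustrative assumptions, not part of this library.

.. code-block:: python

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # illustrative choice
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    model.eval()

    def pointwise_score(query: str, passage: str) -> float:
        """Relevance score from P('Yes') vs P('No') for the next token."""
        prompt = (
            f"Passage: {passage}\nQuery: {query}\n"
            "Does the passage answer the query? Answer 'Yes' or 'No'"
        )
        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            next_token_logits = model(**inputs).logits[0, -1]
        # take the first sub-token id of "Yes" / "No" (tokenizer dependent)
        yes_id = tokenizer.encode("Yes", add_special_tokens=False)[0]
        no_id = tokenizer.encode("No", add_special_tokens=False)[0]
        probs = torch.softmax(next_token_logits[[yes_id, no_id]], dim=-1)
        return probs[0].item()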


Reference
-------------------

- https://github.com/ielab/llm-rankers/tree/main
41 changes: 41 additions & 0 deletions examples/1_retrieval/README.md
@@ -26,3 +26,44 @@ torchrun --nproc_per_node 1 \
--temperature 0.02 \
--use_inbatch_neg false
```


If you want to fine-tune an LLM for embeddings (a hedged inference sketch follows the command below):

- add a `query_instruction`
  - "Given a query and a relevant document, retrieve the documents that are pertinent to the query\nQuery: "
- use the appropriate `pooling_method`
  - `last` (last-token pooling for decoder-only models)
- consider reducing `per_device_train_batch_size`, since the model is large
- set `use_lora` to `True` to enable LoRA (PEFT) fine-tuning

```shell
MODEL_NAME="intfloat/e5-mistral-7b-instruct"
TRAIN_DATA="/t2_ranking.jsonl"
OUTPUT_DIR="/t2_output"

torchrun --nproc_per_node 1 \
-m retrievals.pipelines.embed \
--output_dir $OUTPUT_DIR \
--overwrite_output_dir \
--model_name_or_path $MODEL_NAME \
--do_train \
--train_data $TRAIN_DATA \
--positive_key positive \
--negative_key negative \
--use_lora True \
--query_instruction "Given a query and a relevant document, retrieve the documents that are pertinent to the query\nQuery: " \
--document_instruction '# Document: ' \
--learning_rate 3e-5 \
--bf16 \
--num_train_epochs 5 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 1 \
--dataloader_drop_last True \
--query_max_length 128 \
--document_max_length 256 \
--train_group_size 2 \
--logging_steps 100 \
--temperature 0.02 \
--use_inbatch_neg false
```
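
After training, a minimal inference sketch using plain `transformers` + `peft` with last-token pooling (the adapter path matches `OUTPUT_DIR` above; loading through `retrievals.AutoModelForEmbedding` may differ, so treat this as an assumption-level illustration):

```python
# Hedged sketch: encode queries with the saved LoRA adapter and last-token pooling.
import torch
import torch.nn.functional as F
from peft import PeftModel
from transformers import AutoModel, AutoTokenizer

base_model_name = "intfloat/e5-mistral-7b-instruct"
adapter_dir = "/t2_output"  # OUTPUT_DIR from the command above

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # required by the pooling logic below

model = AutoModel.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(model, adapter_dir)
model.eval()

instruction = (
    "Given a query and a relevant document, retrieve the documents "
    "that are pertinent to the query\nQuery: "
)

def encode_queries(queries):
    batch = tokenizer(
        [instruction + q for q in queries],
        padding=True, truncation=True, max_length=128, return_tensors="pt",
    )
    with torch.no_grad():
        hidden = model(**batch).last_hidden_state
    # last-token pooling: hidden state of the final non-padding token
    last_index = batch["attention_mask"].sum(dim=1) - 1
    embeddings = hidden[torch.arange(hidden.size(0)), last_index]
    return F.normalize(embeddings, p=2, dim=-1)

print(encode_queries(["what is contrastive learning?"]).shape)
```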
35 changes: 35 additions & 0 deletions examples/1_retrieval/eval/README.md
@@ -0,0 +1,35 @@
# Evaluation


```python
from typing import List, Union, Dict
import numpy as np
from retrievals import AutoModelForEmbedding


class AutoModelForEmbeddingEval(AutoModelForEmbedding):
    def __init__(self, **kwargs):
        # Forward all arguments so the base embedding model, tokenizer and
        # query_instruction are configured on the parent class
        super().__init__(**kwargs)

    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
        """For MTEB evaluation, retrieval task.

        If a query instruction is set, it is prepended to each query text.
        """
        if self.query_instruction is not None:
            input_texts = ['{}{}'.format(self.query_instruction, q) for q in queries]
        else:
            input_texts = queries
        return self.encode_from_text(input_texts, batch_size=4)

    def encode_corpus(self, corpus: List[Union[Dict[str, str], str]], **kwargs) -> np.ndarray:
        """For MTEB evaluation, retrieval task.

        Corpus entries given as dicts are joined as "title text" before encoding.
        """
        if isinstance(corpus[0], dict):
            input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
        else:
            input_texts = corpus
        return self.encode_from_text(input_texts, batch_size=4)
```
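
A hedged usage sketch with the `mteb` package; the task name, the `from_pretrained` arguments, and the pooling choice below are illustrative assumptions rather than a fixed recipe.

```python
# Hedged sketch: run an MTEB retrieval task with the eval wrapper above.
import mteb

model = AutoModelForEmbeddingEval.from_pretrained(
    "intfloat/e5-mistral-7b-instruct",  # illustrative checkpoint
    pooling_method="last",
    query_instruction="Query: ",
)
tasks = mteb.get_tasks(tasks=["T2Retrieval"])
evaluation = mteb.MTEB(tasks=tasks)
evaluation.run(model, output_folder="mteb_results")
```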
11 changes: 8 additions & 3 deletions examples/2_rerank/README.md
@@ -16,12 +16,12 @@ ds_train.to_json("t2_ranking.jsonl", force_ascii=False)
```

## train
- cross encoder
cross encoder
```shell
python train_cross_encoder.py
```

- cross encoder

```shell
MODEL_NAME="BAAI/bge-reranker-base"
TRAIN_DATA="/t2_ranking.jsonl"
@@ -48,7 +48,12 @@ torchrun --nproc_per_node 1 \
```


- colbert
colbert
```shell
python train_colbert.py
```


LLM
- `AutoModelForRanking.from_pretrained(model_name_or_path, causal_lm=True)` (see the sketch below)
- Prompt: "Given a query with a relevant body, determine whether the document is pertinent to the query by providing a prediction of either 'Yes' or 'No'."
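
A hedged usage sketch for the LLM reranker: `causal_lm=True` follows the note above, but the `compute_score` call, the backbone model name, and how the Yes/No prompt is applied internally are assumptions for illustration, not a confirmed API.

```python
# Hedged sketch: score query-document pairs with an LLM-based reranker.
# The compute_score signature and prompt handling are assumptions here.
from retrievals import AutoModelForRanking

model_name_or_path = "Qwen/Qwen2-1.5B-Instruct"  # illustrative LLM backbone

model = AutoModelForRanking.from_pretrained(model_name_or_path, causal_lm=True)

scores = model.compute_score(
    [
        ["how to fine-tune an embedding model", "This tutorial walks through LoRA fine-tuning of e5-mistral."],
        ["how to fine-tune an embedding model", "Today's weather is sunny with a light breeze."],
    ]
)
print(scores)  # a higher score should indicate a more relevant document
```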
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,9 +1,9 @@
numpy
numpy<=1.26.0
torch
transformers
accelerate
peft
datasets
faiss-cpu==1.8.0
faiss-cpu
scikit-learn
tqdm
1 change: 0 additions & 1 deletion src/retrievals/__init__.py
@@ -1,6 +1,5 @@
from .data.collator import (
ColBertCollator,
LLMRerankCollator,
PairCollator,
RerankCollator,
TripletCollator,
41 changes: 0 additions & 41 deletions src/retrievals/data/collator.py
@@ -209,47 +209,6 @@ def __call__(self, features: Union[List[Dict[str, Any]], List]) -> BatchEncoding
return batch


class LLMRerankCollator(DataCollatorForSeq2Seq):
"""Rerank collator for casual llm, with examples query, positive and negative"""

query_max_length: int = 32
document_max_length: int = 128
query_instruction: Optional[str] = None
document_instruction: Optional[str] = None

def __call__(self, features: List[Dict[str, Any]], return_tensors='pt'):
if return_tensors is None:
return_tensors = self.return_tensors

if isinstance(features[0], list):
features = sum(features, [])

labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None
if labels is not None:
max_label_length = max(len(l) for l in labels)

padding_side = self.tokenizer.padding_side
for feature in features:
remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"]))
if isinstance(feature["labels"], list):
feature["labels"] = (
feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"]
)
elif padding_side == "right":
feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64)
else:
feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)

collated = self.tokenizer.pad(
features,
padding=self.padding,
max_length=self.query_max_length + self.document_max_length,
return_tensors=return_tensors,
pad_to_multiple_of=self.pad_to_multiple_of,
)
return collated


class ColBertCollator(DataCollatorWithPadding):
def __init__(
self,
