fix: rag benchmark (#57)
LongxingTan committed Jun 13, 2024
1 parent cdcdc8f commit e405abc
Showing 29 changed files with 632 additions and 159 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -196,7 +196,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="cls")
# model = model.set_train_type('pointwise') # 'pointwise', 'pairwise', 'listwise'
optimizer = AdamW(model.parameters(), lr=5e-5)
num_train_steps=int(len(train_dataset) / batch_size * epochs)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)

training_arguments = TrainingArguments(
2 changes: 1 addition & 1 deletion README_ja-JP.md
@@ -199,7 +199,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="cls")
# model = model.set_train_type('pointwise') # 'pointwise', 'pairwise', 'listwise'
optimizer = AdamW(model.parameters(), lr=5e-5)
num_train_steps=int(len(train_dataset) / batch_size * epochs)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)

training_arguments = TrainingArguments(
29 changes: 16 additions & 13 deletions README_zh-CN.md
@@ -30,7 +30,7 @@
**Open-Retrievals** helps developers apply text embeddings conveniently in information retrieval, large language models, and related fields, and quickly build retrieval, ranking, RAG, and other applications.
- `AutoModelForEmbedding` unifies embedding, retrieval, and reranking
- Fine-tune embedding and rerank models with a variety of contrastive learning and point-wise, pairwise, and listwise objectives
- Integrate with Langchain and LlamaIndex to quickly produce a RAG demo
- Customize, or integrate with Langchain and LlamaIndex, to quickly produce a RAG demo

![structure](./docs/source/_static/structure.png)

@@ -59,7 +59,7 @@ python -m pip install -U git+https://github.com/LongxingTan/open-retrievals.git

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-WBMisdWLeHUKlzJ2DrREXY_kSV8vjP3?usp=sharing)

**Text embeddings using pretrained weights**
**Embedding: using pretrained weights**
```python
from retrievals import AutoModelForEmbedding

@@ -77,7 +77,7 @@ scores = (embeddings[:2] @ embeddings[2:].T) * 100
print(scores.tolist())
```

**Retrieval with the Faiss vector database**
**Retrieval: using the Faiss vector database**
```python
from retrievals import AutoModelForEmbedding, AutoModelForRetrieval

@@ -93,7 +93,7 @@ dists, indices = matcher.similarity_search(query_embed, index_path=index_path)
print(indices)
```

**Reranking**
**Reranking: using pretrained weights**
```python
from retrievals import AutoModelForRanking

@@ -106,7 +106,7 @@ scores_list = rerank_model.compute_score(
print(scores_list)
```

**Building a RAG application with Langchain**
**RAG: with Langchain**

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1fJC-8er-a4NRkdJkwWr4On7lGt9rAO4P?usp=sharing)

@@ -178,13 +178,16 @@ print(response)
```


**Fine-tuning the text embedding model**
**Embedding model fine-tuning**

- Model performance fine-tuned in [T2Ranking](https://huggingface.co/datasets/THUIR/T2Ranking)
[//]: # (- Model performance fine-tuned in [T2Ranking](https://huggingface.co/datasets/THUIR/T2Ranking))

| Model | Size | AP<sup>val</sup> | AP<sub>50</sub><sup>val</sup> | AP<sub>75</sub><sup>val</sup> |
| :-- | :-: | :-: | :-: | :-: |
| TripletLoss | 672 | 47.7% |52.6% | 61.4% |
[//]: # ()
[//]: # (| Model | Size | AP<sup>val</sup> | AP<sub>50</sub><sup>val</sup> | AP<sub>75</sub><sup>val</sup> |)

[//]: # (| :-- | :-: | :-: | :-: | :-: |)

[//]: # (| TripletLoss | 672 | 47.7% |52.6% | 61.4% |)


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/17KXe2lnNRID-HiVvMtzQnONiO74oGs91?usp=sharing)
@@ -206,7 +209,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="cls")
# model = model.set_train_type('pointwise') # 'pointwise', 'pairwise', 'listwise'
optimizer = AdamW(model.parameters(), lr=5e-5)
num_train_steps=int(len(train_dataset) / batch_size * epochs)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)

training_arguments = TrainingArguments(
@@ -227,7 +230,7 @@ trainer.scheduler = scheduler
trainer.train()
```

- One-click training
- Shell-based training

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1w2dRoRThG6DnUW46swqEUuWySKS1AXCp?usp=sharing)

@@ -254,7 +257,7 @@ torchrun --nproc_per_node 1 \
```


**Fine-tuning the rerank model**
**Rerank model fine-tuning**

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QvbUkZtG56SXomGYidwI4RQzwODQrWNm?usp=sharing)

16 changes: 7 additions & 9 deletions docs/source/embed.rst
@@ -1,10 +1,9 @@
Embedding
====================
==============================

.. _embed:



Fine-tuning text embedding
------------------------------

@@ -49,12 +48,11 @@ Fine-tuning text embedding
Point-wise
--------------
Point wise
--------------------------

arcface
- layer-wise learning rate
- batch size has a large impact
- adjust arcface_margin dynamically; the margin value has a large impact
- arc_weight initialization
- a stateful training loss is not well suited to also running the evaluation metrics during every training epoch
- layer-wise learning rate (see the sketch below)
- batch size is important
- dynamic arcface_margin; the margin value is important
- arc_weight initialization
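
A minimal sketch of the layer-wise learning rate idea, assuming a BERT-style backbone whose blocks are reachable as ``model.encoder.layer``; the helper below is illustrative only, not an open-retrievals API.

.. code-block:: python

    from torch.optim import AdamW

    def layerwise_parameter_groups(backbone, base_lr=5e-5, decay=0.9):
        # Give the top encoder layer the largest learning rate and decay it
        # geometrically toward the bottom; adapt the attribute path to your model.
        groups = []
        lr = base_lr
        for layer in reversed(list(backbone.encoder.layer)):  # top layer first
            groups.append({"params": layer.parameters(), "lr": lr})
            lr *= decay
        groups.append({"params": backbone.embeddings.parameters(), "lr": lr})  # embeddings learn slowest
        return groups

    # optimizer = AdamW(layerwise_parameter_groups(backbone), lr=5e-5)
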
8 changes: 2 additions & 6 deletions docs/source/index.rst
@@ -1,8 +1,3 @@
.. Retrievals documentation master file, created by
sphinx-quickstart on Mon Feb 19 14:43:55 2024.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
Open-Retrievals Documentation
======================================
.. raw:: html
@@ -21,7 +16,8 @@ Installation
Install the **Prerequisites**

* transformers
* faiss-cpu / faiss-gpu
* peft
* faiss-cpu


Now you are ready, proceed with
102 changes: 91 additions & 11 deletions docs/source/quick-start.rst
@@ -3,10 +3,13 @@ Quick start

.. _quick-start:

Use the pretrained weight as embedding
---------------------------------------------
Open-retrievals is designed to simplify information retrieval and RAG applications, with a focus on retrieval and reranking.

You can use the pretrained embedding easily from transformers or sentence-transformers.

1. Embedding
-----------------------------

We can use the pretrained embedding easily from transformers or sentence-transformers.

.. code-block:: python
@@ -18,18 +21,95 @@ You can use the pretrained embedding easily from transformers or sentence-transf
sentence_embeddings = model.encode(sentences, normalize_embeddings=True, convert_to_tensor=True)
print(sentence_embeddings)
.. code::
output
Embedding fine-tuned
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If we want to further improve retrieval performance, one option is to fine-tune the embedding model weights, which projects the query and answer vectors into a similar representation space.

.. code-block:: python
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
from retrievals import AutoModelForEmbedding, RetrievalTrainer, PairCollator, TripletCollator
from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss
model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
batch_size: int = 128
epochs: int = 3
train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
train_dataset = train_dataset.rename_columns({'sentence1': 'query', 'sentence2': 'document'})
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="cls")
optimizer = AdamW(model.parameters(), lr=5e-5)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)
training_arguments = TrainingArguments(
output_dir='./checkpoints',
num_train_epochs=epochs,
per_device_train_batch_size=batch_size,
remove_unused_columns=False,
)
trainer = RetrievalTrainer(
model=model,
args=training_arguments,
train_dataset=train_dataset,
data_collator=PairCollator(tokenizer, query_max_length=128, document_max_length=128),
loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)),
)
trainer.optimizer = optimizer
trainer.scheduler = scheduler
trainer.train()
2. Indexing
-----------------------------

Save the document embeddings offline.

.. code-block:: python
from retrievals import AutoModelForEmbedding, AutoModelForRetrieval
sentences = ['A dog is chasing car.', 'A man is playing a guitar.']
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
index_path = './database/faiss/faiss.index'
model = AutoModelForEmbedding.from_pretrained(model_name_or_path)
model.build_index(sentences, index_path=index_path)
query_embed = model.encode("He plays guitar.")
matcher = AutoModelForRetrieval()
dists, indices = matcher.similarity_search(query_embed, index_path=index_path)
print(indices)
3. Rerank
-----------------------------

If we have multiple retrieval sources or want a better ordering of candidates, we can add reranking to the pipeline.

.. code-block:: python
Fine tune the transformer pretrained weight by contrastive learning
----------------------------------------------------------------------
from retrievals import AutoModelForRanking
model_name_or_path: str = "BAAI/bge-reranker-base"
rerank_model = AutoModelForRanking.from_pretrained(model_name_or_path)
scores_list = rerank_model.compute_score(["In 1974, I won the championship in Southeast Asia in my first kickboxing match", "In 1982, I defeated the heavy hitter Ryu Long."])
print(scores_list)
Query search by faiss
--------------------------
Rerank fine-tuned
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Rerank to enhance the performance
----------------------------------------

4. RAG
-----------------------------

Langchain example for RAG
--------------------------------
We can easily use open-retrievals to build a RAG application, either on its own (a minimal sketch follows) or integrated with LangChain and LlamaIndex.
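
A minimal sketch of a custom retrieve-and-rerank RAG flow built only from the APIs shown above; the pair format passed to ``compute_score``, the shape of ``indices``, and the final LLM call are assumptions rather than confirmed open-retrievals behavior.

.. code-block:: python

    from retrievals import AutoModelForEmbedding, AutoModelForRetrieval, AutoModelForRanking

    documents = [
        "Faiss is a library for efficient similarity search.",
        "Open-Retrievals fine-tunes embedding and rerank models.",
    ]
    index_path = './database/faiss/faiss.index'

    # 1. embed and index the documents offline
    embedder = AutoModelForEmbedding.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    embedder.build_index(documents, index_path=index_path)

    # 2. retrieve candidates for a query
    query = "What is Faiss used for?"
    query_embed = embedder.encode(query)
    matcher = AutoModelForRetrieval()
    dists, indices = matcher.similarity_search(query_embed, index_path=index_path)
    candidates = [documents[int(i)] for i in indices[0]]  # assumes indices is (n_queries, top_k)

    # 3. rerank the candidates (query-document pair input is an assumption)
    reranker = AutoModelForRanking.from_pretrained("BAAI/bge-reranker-base")
    scores = reranker.compute_score([[query, doc] for doc in candidates])

    # 4. build a prompt from the best passage and hand it to any LLM client
    best = candidates[max(range(len(scores)), key=lambda i: scores[i])]
    prompt = f"Answer from the context.\nContext: {best}\nQuestion: {query}"
    # response = llm.generate(prompt)  # hypothetical LLM call, not part of open-retrievals
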
14 changes: 14 additions & 0 deletions docs/source/rag.rst
@@ -11,6 +11,11 @@ RAG could help solve the false information, out-of-date information, and data se
* Output reference for explainability
* LLM Hallucination


Integrated with Langchain
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


.. code-block:: python
from retrievals.tools.langchain import LangchainEmbedding, LangchainReranker, LangchainLLM
@@ -73,6 +78,15 @@ RAG could help solve the false information, out-of-date information, and data se
print(response)
Integrated with Llamaindex
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


Custom RAG
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~



Enhance RAG Performance
---------------------------

2 changes: 1 addition & 1 deletion docs/source/rerank.rst
@@ -3,7 +3,7 @@ Rerank

.. _rerank:

Use pretrained rerank
Rerank by pretrained
-------------------------

.. code-block:: python
14 changes: 11 additions & 3 deletions src/retrievals/data/collator.py
@@ -10,6 +10,7 @@ def __init__(
tokenizer: PreTrainedTokenizer,
query_max_length: int = 32,
document_max_length: int = 128,
append_eos_token: bool = False,
query_key: str = 'query',
document_key: str = 'positive',
) -> None:
@@ -74,6 +75,7 @@ def __init__(
tokenizer: PreTrainedTokenizer,
query_max_length: int = 32,
document_max_length: int = 128,
append_eos_token: bool = False,
query_key: str = 'query',
positive_key: str = 'positive',
negative_key: Optional[str] = 'negative',
@@ -118,6 +120,8 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
tokenize_fn = self.tokenizer
tokenize_args = {
"truncation": True,
"return_token_type_ids": False,
"add_special_tokens": True,
}
else:
tokenize_fn = self.tokenizer.pad
@@ -147,6 +151,7 @@ def __init__(
self,
tokenizer: PreTrainedTokenizer,
max_length: int = 128,
append_eos_token: bool = False,
query_key: str = 'query',
document_key: str = 'document',
):
@@ -225,6 +230,11 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
query_texts = [feature[self.query_key] for feature in features]
pos_texts = [feature[self.positive_key] for feature in features]

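# flatten nested query/positive lists (mirroring the negative-text handling below) so each sub-text is tokenized as its own row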
if isinstance(query_texts[0], list):
query_texts = sum(query_texts, [])
if isinstance(pos_texts[0], list):
pos_texts = sum(pos_texts, [])

if isinstance(query_texts[0], str):
tokenize_fn = self.tokenizer
tokenize_args = {
@@ -252,6 +262,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:

if self.negative_key in features[0]:
neg_texts = [feature[self.negative_key] for feature in features]

if isinstance(neg_texts[0], list):
neg_texts = sum(neg_texts, []) # flatten nested list

@@ -264,7 +275,4 @@
)
batch.update({'neg_input_ids': neg_inputs['input_ids'], 'neg_attention_mask': neg_inputs['attention_mask']})

# if 'labels' in features[0].keys():
# labels = [feature['labels'] for feature in features]
# batch['labels'] = torch.tensor(labels, dtype=torch.float32)
return batch