fix: embed benchmark
LongxingTan committed Jun 7, 2024
1 parent 24bdc39 commit b73e6db
Showing 26 changed files with 221 additions and 121 deletions.
14 changes: 2 additions & 12 deletions README.md
@@ -48,19 +48,9 @@ pip install peft # if necessary
pip install open-retrievals
```

-[//]: # (**With conda**)
-
-[//]: # (```shell)
-
-[//]: # (conda install open-retrievals -c conda-forge)
-
-[//]: # (```)

**With source code**
```shell
-git clone https://github.com/LongxingTan/open-retrievals
-cd open-retrievals
-pip install -e .
+python -m pip install -U git+https://github.com/LongxingTan/open-retrievals.git
```


@@ -241,7 +231,7 @@ epochs: int = 3

train_dataset = RerankDataset('./t2rank.json', positive_key='pos', negative_key='neg')
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
-model = AutoModelForRanking.from_pretrained(model_name_or_path, pooling_method="mean")
+model = AutoModelForRanking.from_pretrained(model_name_or_path)
optimizer = AdamW(model.parameters(), lr=learning_rate)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)
2 changes: 1 addition & 1 deletion README_ja-JP.md
@@ -245,7 +245,7 @@ epochs: int = 3

train_dataset = RerankDataset('./t2rank.json', positive_key='pos', negative_key='neg')
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
-model = AutoModelForRanking.from_pretrained(model_name_or_path, pooling_method="mean")
+model = AutoModelForRanking.from_pretrained(model_name_or_path)
optimizer = AdamW(model.parameters(), lr=learning_rate)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)
36 changes: 30 additions & 6 deletions README_zh-CN.md
@@ -49,9 +49,7 @@ pip install open-retrievals

**Install from source**
```shell
-git clone https://github.com/LongxingTan/open-retrievals
-cd open-retrievals
-pip install -e .
+python -m pip install -U git+https://github.com/LongxingTan/open-retrievals.git
```


@@ -253,14 +251,15 @@ trainer.train()

```shell
MODEL_NAME='BAAI/bge-small-zh-v1.5'
+OUTPUT_DIR="/train_out"

torchrun --nproc_per_node 1 \
  -m retrievals.pipelines.embed \
-  --output_dir train \
+  --output_dir $OUTPUT_DIR \
  --overwrite_output_dir \
  --model_name_or_path $MODEL_NAME \
  --do_train \
-  --train_data train.jsonl \
+  --train_data t2_ranking.jsonl \
  --learning_rate 3e-5 \
  --fp16 \
  --num_train_epochs 5 \
@@ -289,7 +288,7 @@ epochs: int = 3

train_dataset = RerankDataset('./t2rank.json', positive_key='pos', negative_key='neg')
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
-model = AutoModelForRanking.from_pretrained(model_name_or_path, pooling_method="mean")
+model = AutoModelForRanking.from_pretrained(model_name_or_path)
optimizer = AdamW(model.parameters(), lr=learning_rate)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)
@@ -312,6 +311,31 @@ trainer.scheduler = scheduler
trainer.train()
```

+```shell
+MODEL_NAME="BAAI/bge-reranker-base"
+OUTPUT_DIR="/train_out"
+
+torchrun --nproc_per_node 1 \
+  -m retrievals.pipelines.rerank \
+  --output_dir $OUTPUT_DIR \
+  --overwrite_output_dir \
+  --model_name_or_path $MODEL_NAME \
+  --do_train \
+  --train_data t2_ranking.jsonl \
+  --positive_key positive \
+  --negative_key negative \
+  --learning_rate 3e-5 \
+  --fp16 \
+  --num_train_epochs 3 \
+  --per_device_train_batch_size 64 \
+  --dataloader_drop_last True \
+  --max_length 512 \
+  --max_negative_samples 7 \
+  --unfold_each_positive false \
+  --save_total_limit 2 \
+  --logging_steps 100
+```


## References and Acknowledgements
- [sentence-transformers](https://github.com/UKPLab/sentence-transformers)
Empty file added examples/1_retrieval/README.md
File renamed without changes.
@@ -4,7 +4,7 @@
# accelerate launch --config_file conf_ds.yaml \

accelerate launch \
-  --config_file conf/conf_llm.yaml \
+  --config_file conf_llm.yaml \
  llm_finetune_for_embed.py \
  --model_name_or_path mistralai/Mistral-7B-v0.1 \
  --train_data \
53 changes: 53 additions & 0 deletions examples/2_rerank/README.md
@@ -0,0 +1,53 @@
# Ranking Examples

## Prepare data
```python
from datasets import load_dataset

dataset = load_dataset("C-MTEB/T2Reranking", split="dev")
ds = dataset.train_test_split(test_size=0.1, seed=42)

ds_train = (
    ds["train"]
    .filter(lambda x: len(x["positive"]) > 0 and len(x["negative"]) > 0)
)

ds_train.to_json("t2_ranking.jsonl", force_ascii=False)
```
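
To sanity-check the exported file, it helps to read one row back. A minimal sketch, assuming each JSONL line keeps the `query`, `positive`, and `negative` fields used in the filter above:
```python
import json

# Peek at the first exported row: one query with its lists of
# positive and negative passages.
with open("t2_ranking.jsonl") as f:
    row = json.loads(next(f))
print(row["query"], len(row["positive"]), len(row["negative"]))
```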

## Train
- cross encoder
```shell
python train_cross_encoder.py
```

- cross encoder (torchrun pipeline)
```shell
MODEL_NAME="BAAI/bge-reranker-base"
TRAIN_DATA="/t2_ranking.jsonl"
OUTPUT_DIR="/t2_output"

torchrun --nproc_per_node 1 \
  -m retrievals.pipelines.rerank \
  --output_dir $OUTPUT_DIR \
  --overwrite_output_dir \
  --model_name_or_path $MODEL_NAME \
  --do_train \
  --train_data $TRAIN_DATA \
  --positive_key positive \
  --negative_key negative \
  --learning_rate 3e-5 \
  --fp16 \
  --num_train_epochs 3 \
  --per_device_train_batch_size 32 \
  --dataloader_drop_last True \
  --max_length 512 \
  --max_negative_samples 2 \
  --logging_steps 100
```


- colbert
```shell
python train_colbert.py
```
9 changes: 4 additions & 5 deletions examples/2_rerank/train_colbert.py
@@ -10,7 +10,6 @@
from retrievals import (
    AutoModelForRanking,
    ColBertCollator,
-    RerankCollator,
    RerankDataset,
    RerankTrainer,
)
@@ -19,12 +18,12 @@


model_name_or_path: str = "microsoft/deberta-v3-base"
-max_length: int = 128
+max_length: int = 512
learning_rate: float = 3e-5
-batch_size: int = 4
+batch_size: int = 32
epochs: int = 3

-train_dataset = RerankDataset("./t2rank_100.json", positive_key="pos", negative_key="neg")
+train_dataset = RerankDataset("t2_ranking.jsonl", positive_key="positive", negative_key="negative")
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForRanking.from_pretrained(
    model_name_or_path,
@@ -57,7 +56,7 @@
        tokenizer,
        query_max_length=max_length,
        document_max_length=max_length,
-        positive_key="document",
+        positive_key="positive",
    ),
)
trainer.optimizer = optimizer
10 changes: 5 additions & 5 deletions examples/2_rerank/train_cross_encoder.py
@@ -11,14 +11,14 @@
transformers.logging.set_verbosity_error()

model_name_or_path: str = "microsoft/deberta-v3-base"
-max_length: int = 128
+max_length: int = 512
learning_rate: float = 3e-5
-batch_size: int = 4
+batch_size: int = 32
epochs: int = 3

-train_dataset = RerankDataset("./t2rank_100.json", positive_key="pos", negative_key="neg")
+train_dataset = RerankDataset("t2_ranking.jsonl", positive_key="positive", negative_key="negative")
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
-model = AutoModelForRanking.from_pretrained(model_name_or_path, pooling_method="mean")
+model = AutoModelForRanking.from_pretrained(model_name_or_path)
optimizer = AdamW(model.parameters(), lr=learning_rate)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_cosine_schedule_with_warmup(
@@ -38,7 +38,7 @@
    model=model,
    args=training_args,
    train_dataset=train_dataset,
-    data_collator=RerankCollator(tokenizer, query_max_length=max_length, document_max_length=max_length),
+    data_collator=RerankCollator(tokenizer, max_length=max_length),
)
trainer.optimizer = optimizer
trainer.scheduler = scheduler
Empty file added examples/3_rag/README.md
7 changes: 2 additions & 5 deletions src/retrievals/data/collator.py
@@ -146,18 +146,15 @@ class RerankCollator(DataCollatorWithPadding):
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
-        query_max_length: int = 32,
-        document_max_length: int = 128,
+        max_length: int = 128,
        query_key: str = 'query',
        document_key: str = 'document',
    ):
        self.tokenizer = tokenizer
        if not hasattr(self.tokenizer, "pad_token_id") or self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})

-        self.query_max_length = query_max_length
-        self.document_max_length = document_max_length
-        self.max_length = query_max_length + document_max_length
+        self.max_length = max_length
        self.query_key = query_key
        self.document_key = document_key

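This change collapses the two per-field budgets into a single `max_length` for the joint (query, document) encoding. A minimal usage sketch of the new signature, assuming the collator tokenizes each query-document pair jointly and pads to at most `max_length`:
```python
from transformers import AutoTokenizer

from retrievals import RerankCollator

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-base")
collator = RerankCollator(tokenizer, max_length=512)

# One feature dict per pair, using the default query/document keys.
batch = collator([{"query": "what is deep learning", "document": "Deep learning is ..."}])
print(batch["input_ids"].shape)  # assumed (1, n) with n <= 512
```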
31 changes: 22 additions & 9 deletions src/retrievals/data/dataset.py
@@ -1,4 +1,3 @@
-import copy
import logging
import math
import os
@@ -100,6 +99,7 @@ def __init__(
        positive_key: Optional[str] = 'document',
        negative_key: Optional[str] = 'negative',
        max_negative_samples: Optional[int] = None,
+        unfold_each_positive: bool = False,
        args: Optional = None,
        tokenizer: PreTrainedTokenizer = None,
    ):
@@ -110,9 +110,16 @@
        else:
            self.max_negative_samples = max_negative_samples

-        self.query_key = args.query_key or query_key
-        self.positive_key = args.positive_key or positive_key
-        self.negative_key = args.negative_key or negative_key
+        if args:
+            self.query_key = args.query_key or query_key
+            self.positive_key = args.positive_key or positive_key
+            self.negative_key = args.negative_key or negative_key
+            self.unfold_each_positive = args.unfold_each_positive or unfold_each_positive
+        else:
+            self.query_key = query_key
+            self.positive_key = positive_key
+            self.negative_key = negative_key
+            self.unfold_each_positive = unfold_each_positive

        if isinstance(data_name_or_path, datasets.Dataset):
            dataset = data_name_or_path
@@ -155,13 +162,19 @@ def __getitem__(self, item: int):
    def generate_samples(self, dataset):
        samples: List = []
        for data in dataset:
-            for pos_text in data[self.positive_key]:
-                samples.append([data[self.query_key], pos_text, 1])
+            if self.unfold_each_positive:
+                for pos_text in data[self.positive_key]:
+                    samples.append([data[self.query_key], pos_text, 1])
+            else:
+                samples.append([data[self.query_key], random.choice(data[self.positive_key]), 1])

            negative_samples = data[self.negative_key]
-            if self.max_negative_samples:
-                # TODO: random strategy
-                negative_samples = negative_samples[: self.max_negative_samples]
+            if self.max_negative_samples and self.max_negative_samples > 0:
+                if len(negative_samples) < self.max_negative_samples:
+                    num = math.ceil(self.max_negative_samples / len(negative_samples))
+                    negative_samples = random.sample(negative_samples * num, self.max_negative_samples)
+                else:
+                    negative_samples = random.sample(negative_samples, self.max_negative_samples)
            for neg_text in negative_samples:
                samples.append([data[self.query_key], neg_text, 0])
        return samples
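The rewritten branch keeps every query at exactly `max_negative_samples` negatives: a short negative list is tiled before sampling, a long one is sampled without replacement. A self-contained sketch of just that rule:
```python
import math
import random

def sample_negatives(negatives, max_negative_samples):
    """Standalone mirror of the sampling rule in generate_samples above."""
    if len(negatives) < max_negative_samples:
        # Tile the short list so random.sample has a large enough population.
        num = math.ceil(max_negative_samples / len(negatives))
        return random.sample(negatives * num, max_negative_samples)
    return random.sample(negatives, max_negative_samples)

print(sample_negatives(["n1", "n2", "n3"], 7))  # always yields 7 negatives
```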
27 changes: 12 additions & 15 deletions src/retrievals/losses/infonce.py
@@ -25,7 +25,6 @@ def __init__(
        temperature: float = 0.05,
        use_inbatch_negative: bool = True,
        negative_mode: Literal['paired', 'unpaired'] = "unpaired",
-        train_group_size: int = 1,
    ):
        """
        if not normalized: temperature = 1.0, reset temperature = 1.0 due to using inner product to compute similarity
@@ -36,7 +35,6 @@
        self.temperature = temperature
        self.use_inbatch_negative = use_inbatch_negative
        self.negative_mode = negative_mode
-        self.train_group_size = train_group_size
        if self.temperature > 0.5:
            logger.error('InfoNCE loss use normalized and inner product by default, temperature should be 0.01 ~ 0.1')

@@ -52,15 +50,15 @@ def forward(
        if negative_embeddings is None:
            if self.negative_mode == 'unpaired':
                logits = query_embeddings @ positive_embeddings.transpose(-2, -1)
-                labels = torch.arange(logits.size(0), dtype=torch.long, device=device)
-                loss = self.criterion(logits / self.temperature, labels)
+                target = torch.arange(logits.size(0), dtype=torch.long, device=device)
+                loss = self.criterion(logits / self.temperature, target)
            else:
                logits1 = query_embeddings @ positive_embeddings.transpose(-2, -1)
                logits2 = logits1.T
-                labels = torch.arange(logits1.size(0), dtype=torch.long, device=device)
+                target = torch.arange(logits1.size(0), dtype=torch.long, device=device)
                loss = (
-                    self.criterion(logits1 / self.temperature, labels)
-                    + self.criterion(logits2 / self.temperature, labels)
+                    self.criterion(logits1 / self.temperature, target)
+                    + self.criterion(logits2 / self.temperature, target)
                ) / 2
            return loss
        else:
@@ -70,13 +68,12 @@
                similarity = query_embeddings @ logits.transpose(-2, -1)
                similarity = similarity / self.temperature
                similarity = similarity.view(query_embeddings.size(0), -1)
-                labels = torch.arange(query_embeddings.size(0), dtype=torch.long, device=device)
+                target = torch.arange(query_embeddings.size(0), dtype=torch.long, device=device)
            else:
-                logits = torch.cat([positive_embeddings, negative_embeddings], dim=0)
-                logits = logits.view(query_embeddings.size(0), -1, self.train_group_size)
-                similarity = query_embeddings.unsqueeze(1) @ logits
-                similarity = similarity.squeeze(1) / self.temperature
-                similarity = similarity.view(query_embeddings.size(0), -1)
-                labels = torch.zeros(logits.size(0), dtype=torch.long, device=device)
+                similarity = query_embeddings.unsqueeze(1) @ positive_embeddings.unsqueeze(2)
+                negative_similarity = query_embeddings.unsqueeze(1) @ negative_embeddings.unsqueeze(2)
+                similarity = torch.cat([similarity.squeeze(1), negative_similarity.squeeze(1)], dim=1)
+                similarity = similarity / self.temperature
+                target = torch.zeros(query_embeddings.size(0), dtype=torch.long, device=device)

-            return self.criterion(similarity, labels)
+            return self.criterion(similarity, target)
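
The rewritten 'paired' branch scores each query only against its own positive and its own paired negative, so the logits have shape (batch_size, 2) and the positive always sits in column 0. A toy-tensor walk-through of those shapes:
```python
import torch

batch_size, dim = 4, 8
q = torch.randn(batch_size, dim)  # query_embeddings
p = torch.randn(batch_size, dim)  # positive_embeddings, one per query
n = torch.randn(batch_size, dim)  # negative_embeddings, one paired negative per query

pos_sim = (q.unsqueeze(1) @ p.unsqueeze(2)).squeeze(1)    # (4, 1)
neg_sim = (q.unsqueeze(1) @ n.unsqueeze(2)).squeeze(1)    # (4, 1)
similarity = torch.cat([pos_sim, neg_sim], dim=1) / 0.05  # (4, 2), temperature 0.05
target = torch.zeros(batch_size, dtype=torch.long)        # positive is class 0
loss = torch.nn.CrossEntropyLoss()(similarity, target)
print(similarity.shape, loss.item())  # torch.Size([4, 2]) and a scalar loss
```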
1 change: 0 additions & 1 deletion src/retrievals/losses/margin_mse.py
@@ -2,7 +2,6 @@

import torch
from torch import Tensor, nn
-from torch.nn import functional as F


class MarginMSELoss(nn.Module):