-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Add Contrastive Learning #10097
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
ZHUI
merged 4 commits into
PaddlePaddle:develop
from
jie-z-0607:add_contrastive_learning
Mar 18, 2025
Merged
Add Contrastive Learning #10097
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,108 @@ | ||
| # Contrastive Learning (CL) | ||
|
|
||
| 对比学习(Contrastive Learning)是一种自监督学习的方法,旨在通过比较样本之间的相似性和差异性来学习数据的有效表示。在对比学习中,模型通常被训练以最大化相似样本对(正样本对)的相似性,同时最小化非相似样本对(负样本对)的相似性。这种方法不需要明确的标签信息,因此能够利用大量未标注的数据进行训练。 | ||
|
|
||
| 对比学习的优势在于其能够有效利用未标注数据,减少对大规模标注数据的依赖,并且通常能在下游任务中获得强大的泛化能力。这种方法在文本、图像以及其他领域的数据表示学习中都表现出了优异的性能。 | ||
|
|
||
| ## 1.数据准备 | ||
| 以文本向量模型(Embedding Model)的对比学习为例,需要准备的数据样例如下: | ||
| ``` | ||
| {"query": "四季青的炮制方法是什么?", | ||
| "pos_passage": ["取原药材,除去残枝、枯叶及杂质,略润,切成丝,干燥,筛去灰屑。饮片性状:为大小、长短不一的丝状,革质。上表面光滑有光泽,灰绿色或暗褐色,下表面色较浅,主脉微隆。气微清香,味苦、微涩。贮干燥容器内,置阴凉干燥处。"], | ||
| "neg_passage": ['平时多注意锻炼。饮食方面多吃大叶的绿色蔬菜,肉类食用一些白肉,比如鸡肉和鱼肉,水果可以吃一些含果胶多的,比如苹果、桃子、橙子等。']} | ||
| ``` | ||
| **注释**: | ||
| - query : 查询文本 | ||
| - pos_passage : 查询文本对应的正样本列表 | ||
| - neg_passage : 查询文本对应的负样本列表 | ||
|
|
||
| ### 1.1 Query 清洗 | ||
| Embedding Model 进行对比学习时,训练效果高度依赖于数据的质量。如果数据集中存在大量相似的 query,模型可能会陷入过度关注这些样本的误区,忽视其他关键特征,进而干扰训练过程,削弱最终效果。因此,在进行对比学习之前,对 query 进行清洗,去除相似的 query,是提升模型性能的关键步骤。 | ||
|
|
||
| 我们设计了以下步骤,以高效完成相似 query 的清洗任务: | ||
| - **构建向量表示**。利用 Embedding Model 将每个 query 转换为向量表示。 | ||
| - **计算文本相似度**。利用余弦相似度计算 query 之间的相似度,并设置相似度阈值。如果 query 之间的相似度超过阈值,则认为它们是相似的,需进行去重处理。 | ||
| - **多步骤加速**。用了多卡推理技术,充分利用多 GPU 的并行计算能力加速向量表示的构建过程。同时,结合 faiss 库构建高效的向量索引,并通过 GPU 进一步加速相似度计算与 query 召回过程。 | ||
|
|
||
| Query 清洗的示例如下: | ||
| ``` | ||
| from clean_query import Clean_Query | ||
|
|
||
| model_path = 'BAAI/bge-m3' | ||
| tokenizer_path = 'BAAI/bge-m3' | ||
| input_data_path = './toy_data/toy_source.json' | ||
| output_data_path='./toy_data/test_clean.json' | ||
| test_clean = Clean_Query(model_path, tokenizer_path, input_data_path=input_data_path, output_data_path=output_data_path, similarity_threshold=0.70) | ||
| test_clean.clean() | ||
| ``` | ||
|
|
||
| 可以通过多卡推理有效提高清洗效率,示例如下: | ||
| ``` | ||
| python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" clean_query.py | ||
| ``` | ||
|
|
||
| ### 1.2 负样本挖掘 | ||
| 在对比学习框架中,负样本的质量与数量对模型训练效率及收敛速度具有显著影响。高质量的负样本能够有效提升模型区分正负样本的能力,从而加速训练过程并增强模型的泛化性能。然而,从海量数据中高效挖掘出真正具有挑战性的负样本,是一项既复杂又关键的任务。 | ||
|
|
||
| 为此,我们设计了以下步骤,旨在快速精准地从数据集中筛选出有价值的负样本: | ||
|
|
||
| - **构建向量表示**。利用 Embedding Model 将 query 与 positive passage 转换为向量表示。 | ||
| - **负样本识别**。在向量空间中,计算 query 与候选样本之间的余弦相似度,召回识别出那些虽与 query 不直接相关但具有一定相似性的样本作为负样本。这类样本能够促使模型在训练过程中学习到更细腻的特征区分能力。 | ||
| - **多步骤加速**。为提升负样本挖掘的效率,我们采用了多卡推理技术,充分利用多 GPU 的并行计算能力加速向量表示的构建过程。同时,结合 faiss 库构建高效的向量索引,并通过 GPU 进一步加速负样本的召回与筛选过程。 | ||
|
|
||
| 负样本挖掘的示例如下: | ||
| ``` | ||
| from mining_negative_samples import MiningNegativeSamples | ||
|
|
||
| input_data_path='./toy_data/toy_source.json' | ||
| output_data_path='./toy_data/test_min_neg.json' | ||
| model_path = 'BAAI/bge-m3' | ||
| tokenizer_path = 'BAAI/bge-m3' | ||
| test_mining = MiningNegativeSamples(model_path, tokenizer_path, input_data_path=input_data_path, output_data_path=output_data_path) | ||
| test_mining.mining() | ||
| ``` | ||
|
|
||
| 可以通过多卡推理有效提高挖掘效率,示例如下: | ||
| ``` | ||
| python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" mining_negative_samples.py | ||
| ``` | ||
|
|
||
| ## 2.训练 | ||
| Embedding Model 训练代码位置详见: | ||
| - [run_embedding.py](../../../llm/run_embedding.py) | ||
|
|
||
| Embedding Model 训练示例如下: | ||
| ``` | ||
| python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" run_embedding.py ./config/xlm_roberta/emb_argument.json | ||
| ``` | ||
|
|
||
| ## 3.推理评估 | ||
| 训练完成后,可以对 Embedding Model 进行推理评估,评估指标包括:hit rate,MRR,NDCG 等。示例代码如下: | ||
| ``` | ||
| model_path = 'BAAI/bge-m3' | ||
| tokenizer_path = 'BAAI/bge-m3' | ||
| query_pos_passage_path = './toy_data/toy_dev.json' | ||
| neg_passage_path = './toy_data/toy_dev_neg.json' | ||
| eval = Embedding_Evaluation(model_path, tokenizer_path, query_pos_passage_path, neg_passage_path) | ||
| print(eval.evaluate()) | ||
| ``` | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个可以评估 mteb 的话,可以加一段如何评估的。加一个得分 |
||
|
|
||
| 多卡推理评估的示例如下: | ||
| ``` | ||
| python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" embedding_evaluate.py | ||
| ``` | ||
| **注释**: | ||
| - 其中 query_pos_passage_path 为需要评估的查询文本(query)-正样本(positive passage)对,示例如下: | ||
| ``` | ||
| {"query": "四季青的炮制方法是什么?", | ||
| "pos_passage": ["取原药材,除去残枝、枯叶及杂质,略润,切成丝,干燥,筛去灰屑。饮片性状:为大小、长短不一的丝状,革质。上表面光滑有光泽,灰绿色或暗褐色,下表面色较浅,主脉微隆。气微清香,味苦、微涩。贮干燥容器内,置阴凉干燥处。"]} | ||
| ``` | ||
| - neg_passage_path 为需要加入评估的负样本(negative passage)数据,示例如下: | ||
| ``` | ||
| {"neg_passage": ['平时多注意锻炼。饮食方面多吃大叶的绿色蔬菜,肉类食用一些白肉,比如鸡肉和鱼肉,水果可以吃一些含果胶多的,比如苹果、桃子、橙子等。']} | ||
| ``` | ||
| - 推理评估后将打印并返回 Hit_rate、MRR、NDCG 等评价指标。 | ||
|
|
||
| ## Acknowledge | ||
|
|
||
| 我们借鉴了[FlagEmbedding/Tutorials/7_Fine-tuning](https://github.com/FlagOpen/FlagEmbedding/blob/master/Tutorials/7_Fine-tuning/7.2.1_Hard_Negative_Mining.ipynb)中有关挖掘强负样本的代码设计,在此对其作者表示感谢。 | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,256 @@ | ||
| # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import json | ||
|
|
||
| import faiss | ||
| import numpy as np | ||
| import paddle | ||
| import paddle.distributed as dist | ||
| from tqdm import tqdm | ||
|
|
||
| from paddlenlp.transformers import AutoConfig, AutoModel, AutoTokenizer | ||
|
|
||
|
|
||
| class Clean_Query: | ||
| def __init__( | ||
| self, | ||
| model_path, | ||
| tokenizer_path, | ||
| input_data_path, | ||
| output_data_path, | ||
| template="{text}", | ||
| dimension=1024, | ||
| max_src_len=8192, | ||
| normalize=True, | ||
| dtype=None, | ||
| similarity_threshold=0.75, | ||
| ): | ||
| # Initialize the tokenizer | ||
| self.tokenizer = AutoTokenizer.from_pretrained( | ||
| tokenizer_path, | ||
| padding_side="right", | ||
| truncation_side="right", | ||
| ) | ||
|
|
||
| self.config = AutoConfig.from_pretrained(model_path) | ||
| self.config.embedding_negatives_cross_device = False | ||
| self.dtype = dtype if dtype else self.config.dtype | ||
|
|
||
| # Initialize the distributed environment | ||
| dist.init_parallel_env() | ||
| world_size = dist.get_world_size() | ||
| if world_size > 1: | ||
| print(f"Running in multi-GPU mode with {world_size} GPUs.") | ||
| else: | ||
| print("Running in single-GPU or CPU mode.") | ||
|
|
||
| # Initialize the embedding model | ||
| self.model = AutoModel.from_pretrained( | ||
| model_path, config=self.config, dtype=self.dtype, low_cpu_mem_usage=False | ||
| ) | ||
| self.model.eval() | ||
|
|
||
| self.input_data_path = input_data_path | ||
| self.output_data_path = output_data_path | ||
| self.template = template | ||
| self.dimension = dimension | ||
| self.max_src_len = max_src_len | ||
| self.normalize = normalize | ||
| self.similarity_threshold = similarity_threshold | ||
|
|
||
| def _preprocess(self, texts): | ||
| """Pre-process inputs.""" | ||
| template_prefix, template_suffix = self.template.split("{text}") | ||
| prefix_tokens = self.tokenizer(template_prefix, add_special_tokens=False).input_ids | ||
| suffix_tokens = self.tokenizer(template_suffix, add_special_tokens=False).input_ids | ||
|
|
||
| # If the template does not contain a suffix token, add the EOS token | ||
| if len(suffix_tokens) == 0: | ||
| suffix_tokens = [self.tokenizer.eos_token_id] | ||
| # If the template does not contain a prefix token, add the BOS token | ||
| if len(prefix_tokens) == 0: | ||
| prefix_tokens = [self.tokenizer.bos_token_id] | ||
|
|
||
| available_len = self.max_src_len - len(prefix_tokens) - len(suffix_tokens) | ||
| truncated_token_ids = self._batch_truncate_and_tokenize(texts, available_len) | ||
| complete_token_ids = [prefix_tokens + tid + suffix_tokens for tid in truncated_token_ids] | ||
| position_ids = [list(range(len(cid))) for cid in complete_token_ids] | ||
| max_len = max([len(cid) for cid in complete_token_ids]) | ||
| embedding_indices = [[idx, len(cid) - 1] for idx, cid in enumerate(complete_token_ids)] | ||
|
|
||
| inputs = self.tokenizer.pad( | ||
| { | ||
| "input_ids": complete_token_ids, | ||
| "position_ids": position_ids, | ||
| "embedding_indices": embedding_indices, | ||
| }, | ||
| padding="max_length", | ||
| return_attention_mask=True, | ||
| max_length=max_len, | ||
| return_tensors="pd", | ||
| ) | ||
| return inputs | ||
|
|
||
| def _batch_truncate_and_tokenize(self, texts, available_len): | ||
| """Tokenize the batch of texts.""" | ||
| batch_tokenized = self.tokenizer( | ||
| texts, add_special_tokens=False, padding=False, truncation=True, max_length=available_len | ||
| ) | ||
|
|
||
| truncated_token_ids = [token_ids for token_ids in batch_tokenized["input_ids"]] | ||
| return truncated_token_ids | ||
|
|
||
| def _forward(self, inputs, dimension): | ||
| """Run model forward.""" | ||
| input_type = type(inputs["input_ids"]) | ||
| outputs = self.model(**inputs) | ||
| if isinstance(outputs, input_type): | ||
| hidden_states = outputs | ||
| else: | ||
| hidden_states = outputs[0] | ||
| last_hidden_state = hidden_states[:, 0] | ||
|
|
||
| if dimension > self.config.hidden_size: | ||
| raise ValueError( | ||
| f"Dimension ({dimension}) cannot be greater than hidden_size ({self.config.hidden_size})." | ||
| ) | ||
| elif dimension != self.config.hidden_size: | ||
| last_hidden_state = last_hidden_state[:, :dimension] | ||
|
|
||
| if self.normalize: | ||
| last_hidden_state = paddle.nn.functional.normalize(last_hidden_state, axis=-1) | ||
|
|
||
| last_hidden_state = last_hidden_state.astype("float16").tolist() | ||
| return last_hidden_state | ||
|
|
||
| @paddle.no_grad() | ||
| def get_embedding(self, texts, dimension=None): | ||
| """Get inference sequence.""" | ||
| if dimension is None: | ||
| dimension = self.dimension | ||
| inputs = self._preprocess(texts) | ||
| if self.config.model_type in ["xlm-roberta"]: | ||
| del inputs["embedding_indices"] | ||
| del inputs["position_ids"] | ||
| outputs = self._forward(inputs, dimension) | ||
| return outputs | ||
|
|
||
| def clean(self): | ||
| data_list = [] | ||
| with open(self.input_data_path, "r") as f: | ||
| for line in tqdm(f): | ||
| data_list.append(json.loads(line)) | ||
|
|
||
| query_list = [single_data["query"] for single_data in data_list] | ||
|
|
||
| world_size = paddle.distributed.get_world_size() | ||
| rank = paddle.distributed.get_rank() | ||
| chunk_size = len(query_list) // world_size | ||
| if rank == world_size - 1: | ||
| # The last process handles the remaining data | ||
| query_data_chunk = query_list[rank * chunk_size :] | ||
| else: | ||
| query_data_chunk = query_list[rank * chunk_size : (rank + 1) * chunk_size] | ||
|
|
||
| batch_size = 4 # Adjust batch size according to your hardware and needs | ||
| local_q_vecs = [] | ||
|
|
||
| # Use tqdm to iterate over query_data_chunk and get embeddings in batches | ||
| for batch in tqdm(range(0, len(query_data_chunk), batch_size), desc="Processing query embeddings"): | ||
| batch_start = batch | ||
| batch_end = min(batch_start + batch_size, len(query_data_chunk)) | ||
| batch_texts = query_data_chunk[batch_start:batch_end] | ||
|
|
||
| # Call get_embedding to obtain embeddings for the current batch | ||
| batch_embeddings = self.get_embedding(batch_texts) | ||
| local_q_vecs.extend(batch_embeddings) | ||
|
|
||
| local_q_vecs_file = f"local_q_vecs_rank_{rank}.npy" | ||
| np.save(local_q_vecs_file, local_q_vecs) | ||
| dist.barrier() # Ensure all cards have reached this point before continuing | ||
|
|
||
| if rank == 0: | ||
| all_q_vecs_list = [] | ||
| world_size = paddle.distributed.get_world_size() | ||
|
|
||
| for i in range(world_size): | ||
| local_q_vecs_file = f"local_q_vecs_rank_{i}.npy" | ||
|
|
||
| # Load the embedding vector file from each process | ||
| local_q_vecs = np.load(local_q_vecs_file) | ||
| all_q_vecs_list.append(local_q_vecs) | ||
|
|
||
| all_q_vecs = [] | ||
| for q_vecs in all_q_vecs_list: | ||
| all_q_vecs.extend(q_vecs) | ||
| q_vecs = np.asarray(all_q_vecs, dtype=np.float32) | ||
|
|
||
| index = faiss.IndexFlatIP(self.dimension) | ||
| if paddle.is_compiled_with_cuda(): | ||
| co = faiss.GpuMultipleClonerOptions() | ||
| co.shard = False | ||
| co.useFloat16 = False | ||
| index = faiss.index_cpu_to_all_gpus(index, co=co) | ||
|
|
||
| temp_query_embedding = q_vecs[0].reshape(1, -1) | ||
| # faiss.normalize_L2(temp_query_embedding) | ||
| # print(q_vecs.shape) | ||
| # print(temp_query_embedding) | ||
| # print(temp_query_embedding.shape) | ||
| index.add(temp_query_embedding) | ||
|
|
||
| clean_data_list = [data_list[0]] | ||
| temp_query_list = [query_list[0]] | ||
| for i in tqdm(range(1, len(query_list))): | ||
| single_query_embedding = q_vecs[i].reshape(1, -1) | ||
| # faiss.normalize_L2(single_query_embedding) | ||
|
|
||
| if i < 3: | ||
| top_values, top_indices = index.search(single_query_embedding, 1) | ||
| else: | ||
| top_values, top_indices = index.search(single_query_embedding, 3) | ||
|
|
||
| if top_values[0][0] < self.similarity_threshold: | ||
| clean_data_list.append(data_list[i]) | ||
| index.add(single_query_embedding) | ||
| temp_query_list.append(query_list[i]) | ||
| # else: | ||
| # print(query_list[i]) | ||
| # for j in range(top_values.shape[1]): | ||
| # print(f"similarity:{top_values[0][j]} query:{temp_query_list[top_indices[0][j]]}") | ||
| # print('********************************') | ||
| # continue | ||
| # if i%10000==0: | ||
| # print(len(clean_data_list)) | ||
|
|
||
| with open(self.output_data_path, "w", encoding="utf-8") as f: | ||
| for data in clean_data_list: | ||
| f.write(json.dumps(data, ensure_ascii=False)) | ||
| f.write("\n") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| model_path = "BAAI/bge-m3" | ||
| tokenizer_path = "BAAI/bge-m3" | ||
| input_data_path = "./toy_data/toy_source.json" | ||
| output_data_path = "./toy_data/test_clean.json" | ||
| test_clean = Clean_Query( | ||
| model_path, | ||
| tokenizer_path, | ||
| input_data_path=input_data_path, | ||
| output_data_path=output_data_path, | ||
| similarity_threshold=0.70, | ||
| ) | ||
| test_clean.clean() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以讲一下原理。