Skip to content

Update negative sampling in hn_mine to fix issue #464#470

Merged
staoxiao merged 3 commits intoFlagOpen:masterfrom
shtdbb:master
Apr 24, 2024
Merged

Update negative sampling in hn_mine to fix issue #464#470
staoxiao merged 3 commits intoFlagOpen:masterfrom
shtdbb:master

Conversation

@shtdbb
Copy link
Copy Markdown
Contributor

@shtdbb shtdbb commented Feb 20, 2024

关于解决 #464 的修改。
避免困难样本挖掘时,当召回负样本数量少于预设负采样数量,会随机采样到正样本、或重复采样负样本的问题。
修改为,默认从 corpus 中剔除正例和已召回的负例,再进行随机采样;若剔除后 corpus 为空,说明需要重复采样负样本才能满足负采样数量要求,则只剔除正样本、重复采样负样本即可。

@staoxiao
Copy link
Copy Markdown
Collaborator

感谢您的PR!但是目前的操作看起来比较复杂,可能会导致比较大的时间消耗。这块可能还需要再好好考虑一下。

@shtdbb
Copy link
Copy Markdown
Contributor Author

shtdbb commented Apr 24, 2024

感谢您的PR!但是目前的操作看起来比较复杂,可能会导致比较大的时间消耗。这块可能还需要再好好考虑一下。

@staoxiao 您好,我这边修改了一下策略。

方法

我是选择多随机采样一个负样本用于备用,然后采样后查看正样本是否被采样,若正样本在则用备份样本代替,否则直接舍弃备用样本即可。这样每个样本采样后再进行过滤,只需要遍历负样本数量的列表即可。

测试

我这边使用数据集大致测试了一下:使用 10 万个样本的数据集进行负样本挖掘,设置脚本参数
python -m hn_mine --model_name_or_path models/bge-large-zh-v1.5 --input_file dataset.jsonl --range_for_sampling 2-200 --negative_number 100 --output_file dataset_neg.jsonl,设置采样 100 个负样本测试极端情况。计时则单独计算 87-98 行这个 for 循环的运算时间:

for i, data in enumerate(train_data):
query = data['query']
inxs = all_inxs[i][sample_range[0]:sample_range[1]]
filtered_inx = []
for inx in inxs:
if inx == -1: break
if corpus[inx] not in data['pos'] and corpus[inx] != query:
filtered_inx.append(inx)
if len(filtered_inx) > negative_number:
filtered_inx = random.sample(filtered_inx, negative_number)
data['neg'] = [corpus[inx] for inx in filtered_inx]

结果

  • 不做正样本过滤处理:约 8.7s
  • 正样本的后过滤处理:约 9.3s
    从结果上来看,从 10 万个样本中负采样 100 个样本,应该能满足大部分微调的负采样需求,故做正样本的后过滤大致的时间花费个人认为是可以接受的~

@staoxiao
Copy link
Copy Markdown
Collaborator

@shtdbb , 非常感谢您的PR!
有个小问题,data['pos']是一个列表,可能包含多个正样本,无法执行sent != data['pos']
如果您跑通了这个代码,需要检查数据格式是否正确。data['pos']如果是一个字符串的话,训练会有很大问题(代码将随机选取一个字母作为pos)。

代码建议改为这样:

samples = random.sample(corpus, negative_number - len(data['neg']) + len(data['pos']))
samples = [sent for sent in samples if sent not in data['pos']]
data['neg'].extend(samples[:negative_number - len(data['neg'])])

@shtdbb
Copy link
Copy Markdown
Contributor Author

shtdbb commented Apr 24, 2024

@shtdbb , 非常感谢您的PR! 有个小问题,data['pos']是一个列表,可能包含多个正样本,无法执行sent != data['pos']。 如果您跑通了这个代码,需要检查数据格式是否正确。data['pos']如果是一个字符串的话,训练会有很大问题(代码将随机选取一个字母作为pos)。

代码建议改为这样:

samples = random.sample(corpus, negative_number - len(data['neg']) + len(data['pos']))
samples = [sent for sent in samples if sent not in data['pos']]
data['neg'].extend(samples[:negative_number - len(data['neg'])])

非常感谢您的耐心!很不好意思,我忽略了data['pos']类型是list了。感谢提醒!已按照您的建议修改~

@staoxiao
Copy link
Copy Markdown
Collaborator

thanks~

@staoxiao staoxiao merged commit 502c2f2 into FlagOpen:master Apr 24, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants