
add data augmentation strategy #2805

Merged (5 commits into PaddlePaddle:develop, Jul 28, 2022)

Conversation

lugimzzz
Contributor

@lugimzzz lugimzzz commented Jul 14, 2022

PR types

Others

PR changes

APIs

Description

Adds vocabulary-based word-level data augmentation strategies: substitution (synonym, homonym, random word, local vocabulary, combined vocabularies), deletion (random), insertion (synonym, homonym, local vocabulary, random word, combined vocabularies), and swap (random).

Features:

from paddlenlp.data_augmentation.word import WordSubstitute, WordDelete, WordSwap, WordInsert
s1 = '2021年,我再看深度学习领域,无论是自然语言处理、音频信号处理、图像处理、推荐系统,似乎都看到attention混得风生水起,只不过更多时候看到的是它的另一个代号:Transformer。'
s2 = '绝对准确率计算的是完全预测正确的样本占总样本数的比例,而0-1损失计算的是完全预测错误的样本占总样本的比例。'

# create_n: number of augmented sentences to generate; aug_n: number of words to substitute
# Synonym substitution
aug = WordSubstitute('synonym', create_n=2, aug_n=2)
augmented = aug.augment(s1)
print(augmented)
# ['2021年,我再看深度习世界,无论是自然语言处理、音频信号处理、图像处理、推荐系统,似乎都看到attention混得风生水起,只不过更多时候看到的是它的另一个代号:Transformer。', '2021年,我再看吃水学习领域,无论是自然语言处理、音频信号处理、图像处理、推荐系统,似乎都看到attention混得风生水起,只不过更多时候看到的是它的另一个调号:Transformer。']

# A list of inputs is also supported
augmented = aug.augment([s1,s2])
print(augmented)
# [['2021年,我再看深度学习领域,无论自然语言处理、音频信号处理、图像处理、推荐系统,似乎都看到attention混得风生水起,只不过更多时候看到的是它的另一个商标:Transformer。', '2021年,我再看深度学习领域,听由自然语言处理、音频信号处理、图像处理、引进系统,似乎都看到attention混得风生水起,只不过更多时候看到的是它的另一个代号:Transformer。'], ['绝对准确率乘除的是完全预测正确的样本占总样本数的比例,而0-1损失算计的是完全预测错误的样本占总样本的比例。', '绝对准确率计算的是完全预测正确的样本占总样本数的比重,而0-1损失算计的是完全预测错误的样本占总样本的比例。']]

# aug_percent: percentage of words to substitute
aug = WordSubstitute('synonym', create_n=2, aug_percent=0.1)
augmented = aug.augment(s1)
print(augmented)
# ['2021年景,我再看深度读书领域,管自是语言处理、板眼信号处理、图像处理、推荐系统,似乎都看到attention混得风生水起,只不过更多时候看到的是它的另一个商标:Transformer。', '2021年成,我再看深度学习园地,无论是自然言语处理、点子信号处理、图像处理、举荐系统,似乎都看到attention混得风生水起,只不过更多时候看到的是它的另一个调号:Transformer。']

# Homonym substitution
aug = WordSubstitute('homonym', create_n=2, aug_n=2)
augmented = aug.augment(s1)
print(augmented)
# ['2021年,我再看深度学习领域,无论是自然语言处理、音频新好处理、图像处理、推荐系统,似乎都看到attention混得风生水起,只不过更多时候看到的是它的另一个带好:Transformer。', '2021年,我再看深度学习领域,无论是自然语言处理、音频信号处理、图像处理、推荐系统,似乎都坎到attention混得风生水起,只不过更多时候看到的是它的另一个戴皓:Transformer。']


# Random word substitution
aug = WordSubstitute('random', create_n=2, aug_n=2)
augmented = aug.augment(s1)
print(augmented)
# ['原和年,我再看深度学习领域,无论是自然语言处理、音频信号处理、图像处理、推荐系统,似乎都看到attention混得风生水起,只不过更多时候看到的是它的另一个代号:任总ansformer。', '责权年,我再看深度学习领域,无论是自然语言处理、音频信号处理、图像处理、推荐系统,似乎都看到attention混得风生水起,只不过更多时候看到的是它的另一个代号:5.39ansformer。']

# Local vocabulary substitution
aug = WordSubstitute('custom', custom_file_path='data', create_n=2, aug_n=2)
augmented = aug.augment(s1)
print(augmented)

# Combined vocabulary substitution (random words are not supported with combined vocabularies)
aug = WordSubstitute(['homonym',  'synonym'], custom_file_path='data', create_n=2, aug_n=2)
augmented = aug.augment(s1)
print(augmented)
# ['2021年,我再看深度学习领域,无论是本语言处理、音频信号处理、图像处理、推荐系统,似乎都看到attention混得风生水起,只不过更多时候看到的是它的另一个待好:Transformer。', '2021年,我再看深度学习灵玉,无论是自然语言处理、音频信号处理、图像处理、推荐系统,似乎都看到attention混得风生水起,只不过更多时候看到的是它的另一个商标:Transformer。']

# Random deletion
aug = WordDelete(create_n=2, aug_n=2)
augmented = aug.augment(s1)
print(augmented)

# Random swap
aug = WordSwap(create_n=2, aug_n=2)
augmented = aug.augment(s1)
print(augmented)

# Synonym insertion; like substitution, insertion also supports homonyms, a local vocabulary, random words, and combined vocabularies
aug = WordInsert('synonym', create_n=2, aug_n=2)
augmented = aug.augment(s1)
print(augmented)

@lugimzzz lugimzzz requested a review from wawltor July 14, 2022 12:46
@lugimzzz lugimzzz self-assigned this Jul 14, 2022
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_augment import *
Collaborator:

An overall usage document is missing; one can be written once the work is complete.

Contributor Author:

OK, documentation and a tutorial will be added once development is finished.

("word_homonym.json", "a578c04201a697e738f6a1ad555787d5",
"https://bj.bcebos.com/paddlenlp/data/word_homonym.json")
}
self.stop_words = self._get_data('stop_words')
Collaborator:

These variable definitions and functions should follow one convention: if a variable is internal and not meant to be accessed, prefix it with _ to make it semi-private.

Contributor Author:

Variables not meant for external access now start with a single underscore, e.g. self.DATA -> self._DATA
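The convention can be sketched as follows (class and attribute names here are illustrative, not the PR's exact code):

```python
class BaseAugmentSketch:
    """Illustrates the semi-private naming convention discussed above."""

    def __init__(self):
        # Leading underscore marks the lookup table as internal:
        # callers outside the class should not depend on it.
        self._DATA = {"stop_words": "stop_words.txt"}

    def _get_data(self, mode):
        # Internal helper, likewise underscore-prefixed.
        return self._DATA.get(mode)
```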

'''Calculate number of words for data augmentation'''
if size == 0:
return 0
aug_percent = self.aug_percent or 0.02
Collaborator:

Shouldn't aug_percent just default to 0.02 when the class is initialized?

Contributor Author:

Changed; aug_percent is now initialized to 0.02 by default.
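A minimal sketch of the revised constructor (parameter names follow the snippets quoted in this review; the class name is hypothetical):

```python
class AugmentConfigSketch:
    # aug_percent now defaults to 0.02 at initialization, instead of
    # being substituted lazily with `self.aug_percent or 0.02`.
    def __init__(self, aug_n=None, aug_percent=0.02, aug_min=1, aug_max=10):
        self.aug_n = aug_n
        self.aug_percent = aug_percent
        self.aug_min = aug_min
        self.aug_max = aug_max
```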

aug_percent=None,
aug_min=1,
aug_max=10):
paddle.set_device("cpu")
Collaborator:

I don't quite understand why paddle.set_device('cpu') is called here.

Contributor Author:

Removed.

'''Read data as list '''
fullname = self._load_file(mode)
data = []
with open(fullname, 'r', encoding='utf-8') as f:
Collaborator:

Should we verify that this file exists?

Contributor Author:

Added if os.path.exists(fullname):
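The guard pattern might look like this (the helper name and error handling are illustrative, not the PR's exact code):

```python
import os

def read_word_list(fullname):
    # Check first that the vocabulary file exists, then read it line by
    # line; raising keeps a bad path visible instead of silently
    # producing an empty word list.
    if not os.path.exists(fullname):
        raise FileNotFoundError("Vocabulary file not found: " + fullname)
    data = []
    with open(fullname, "r", encoding="utf-8") as f:
        for line in f:
            data.append(line.strip())
    return data
```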


def _generate_random_index(self, seq_tokens, skip=True):
'''Random sample words for insertion/deletion/swap'''
# skip stopping words
Collaborator:

skip -> Skip

Contributor Author:

Fixed.

aug_n = min(aug_n, len(indexes))
return random.sample(indexes, aug_n)

def augment(self, sequences, num_thread=1):
Collaborator:

This is a public function; its purpose and parameters need to be documented clearly.

Contributor Author:

Added.
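A docstring along these lines would cover it (signature taken from the quoted diff; the body here is a stub, not the PR's implementation):

```python
def augment(sequences, num_thread=1):
    """Generate augmented variants of the input sequence(s).

    Args:
        sequences (str or list of str): Input sentence or sentences.
        num_thread (int): Number of worker threads. Only 1 is currently
            supported; multithreading is marked "TO BE DONE" in this PR.

    Returns:
        list: For a single string, a list of augmented sentences; for a
        list input, one such list per input sentence.
    """
    raise NotImplementedError  # docstring/signature sketch only
```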

for sequence in sequences:
output.append(self._augment(sequence))
return output
# TO BE DONE: Multi Thread
Collaborator:

Multiprocessing does need consideration; multiprocessing on Windows in particular needs extra testing, as there are some pitfalls: https://segmentfault.com/a/1190000013681586
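On Windows, multiprocessing uses the spawn start method, so any pool launch must sit behind an entry-point guard; a minimal sketch (the worker function is a stand-in for the real augmentation):

```python
import multiprocessing as mp

def augment_one(sequence):
    # Stand-in for the per-sentence augmentation work.
    return sequence.upper()

def augment_parallel(sequences, num_proc=2):
    with mp.Pool(num_proc) as pool:
        return pool.map(augment_one, sequences)

if __name__ == "__main__":
    # Without this guard, spawn-based platforms (Windows, and macOS by
    # default) re-import the module in each worker and would re-launch
    # the pool recursively.
    print(augment_parallel(["hello", "world"]))
```

Worker functions must also be defined at module level so they can be pickled for the child processes.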

Contributor Author:

To be done.

# limitations under the License.
import random

from paddlenlp.data_augmentation import BaseAugment
Collaborator:

Inside a library, prefer relative imports, e.g. from ..data_augmentation import BaseAugment

Contributor Author:

Changed to from ..data_augmentation import BaseAugment

return indexes


if __name__ == '__main__':
Collaborator:

A main function is not recommended inside library code.

Contributor Author:

Removed.

seq_tokens = self.tokenizer.cut(sequence)
aug_n = self._get_aug_n(len(seq_tokens))
aug_indexes = self.skip_words(seq_tokens)
aug_n = min(aug_n, len(aug_indexes))
Collaborator:

This needs discussion: if very few tokens remain after removing stop words, should we stop skipping them?

Contributor Author (@lugimzzz, Jul 19, 2022):

Added a policy: the number of augmented words aug_n may not exceed len(aug_indexes) * 0.3, i.e. at most roughly one in four words is augmented:

def _get_aug_n(self, size, size_a=None):
    if size == 0:
        return 0
    aug_n = self.aug_n or int(math.ceil(self.aug_percent * size))
    if self.aug_min and aug_n < self.aug_min:
        aug_n = self.aug_min
    elif self.aug_max and aug_n > self.aug_max:
        aug_n = self.aug_max
    if size_a is not None:
        aug_n = min(aug_n, int(math.floor(size_a*0.3)))
    return aug_n
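Stripped of the class context, the policy behaves like this (a standalone restatement for illustration, not the PR's exact method):

```python
import math

def get_aug_n(size, size_a=None, aug_n=None, aug_percent=0.02,
              aug_min=1, aug_max=10):
    # Standalone restatement of _get_aug_n above: pick a word count from
    # aug_n or aug_percent, clamp to [aug_min, aug_max], then cap at 30%
    # of the augmentable tokens so at most about one word in four is
    # touched.
    if size == 0:
        return 0
    n = aug_n or int(math.ceil(aug_percent * size))
    if aug_min and n < aug_min:
        n = aug_min
    elif aug_max and n > aug_max:
        n = aug_max
    if size_a is not None:
        n = min(n, int(math.floor(size_a * 0.3)))
    return n
```

For a 50-token sentence with 10 augmentable tokens, the default 2% rate gives ceil(0.02 * 50) = 1 word, well under the cap of floor(10 * 0.3) = 3; an explicit aug_n=8 would be clipped to 3.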

fullname = self.custom_file_path
elif source_type in ['delete']:
fullname = self.delete_file_path
with open(fullname, 'r', encoding='utf-8') as f:
Collaborator:

Please check whether the file exists.

Contributor Author (@lugimzzz, Jul 19, 2022):

Added if os.path.exists(fullname):



if __name__ == '__main__':
aug = WordInsert(aug_type='synonym', create_n=2, aug_n=1)
Collaborator:

Remove the main function.

Contributor Author:

Removed.

idxes = random.sample(list(range(len(candidate_tokens))),
aug_n)
aug_tokens = []
for idx in idxes:
Collaborator:

Rather than sampling one item at a time, sample once for the whole batch; repeated random.sample calls could become a performance bottleneck here.

Contributor Author:

The impact of random on speed will be measured once development is complete, and an optimization approach chosen then.
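The batched alternative the reviewer describes could look like this (a sketch; the function and variable names are illustrative):

```python
import random

def sample_batched(n_candidates, create_n, aug_n):
    # Draw all create_n * aug_n indexes in one random.sample call, then
    # split the flat list into create_n groups of aug_n, instead of
    # calling into random once per augmented sentence.
    flat = random.sample(range(n_candidates), create_n * aug_n)
    return [flat[i * aug_n:(i + 1) * aug_n] for i in range(create_n)]
```

One behavioral difference to weigh: a single sample draws without replacement across all create_n sentences, so no index repeats between the generated variants, which may or may not match the intended semantics.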

aug_indexes = random.sample(aug_indexes, aug_n)
for aug_index in aug_indexes:
token = self.vocab.to_tokens(
random.randint(0,
Collaborator:

Same as above.

Contributor Author:

Same as above.

self._reverse_sequence(seq_tokens.copy(), [aug_token]))
return sentences

def _reverse_sequence(self, output_seq_tokens, aug_tokens):
Collaborator:

Why is this function named reverse? It does not seem related to reversal.

Contributor Author:

Renamed to _generate_sequence

aug = WordSwap(create_n=2, aug_n=1)
s1 = '2021年,我再看深度学习领域,无论是自然语言处理、音频信号处理、图像处理、推荐系统,似乎都看到attention混得风生水起,只不过更多时候看到的是它的另一个代号:Transformer。'

augmented = aug.augment(s1)
Collaborator:

Same as above.

Contributor Author:

Removed.

if source_type in ['synonym', 'homonym']:
fullname = self._load_file('word_' + source_type)
elif source_type in ['custom']:
fullname = self.custom_file_path
Collaborator:

Same as above.

Contributor Author:

Added if os.path.exists(fullname):

elif source_type in ['delete']:
fullname = self.delete_file_path
with open(fullname, 'r', encoding='utf-8') as f:
substitue_dict = json.load(f)
Collaborator:

Check whether the file exists.

Contributor Author:

Added if os.path.exists(fullname):

aug_indexes = random.sample(aug_indexes, aug_n)
for aug_index in aug_indexes:
token = self.vocab.to_tokens(
random.randint(0,
Collaborator:

Same as above.

Contributor Author:

Same as above.

import paddle
from paddle.dataset.common import md5file
from paddle.utils.download import get_path_from_url
from paddlenlp.utils.env import DATA_HOME
Collaborator:

These are paddlenlp-internal imports; consider using relative imports here.

Contributor Author:

Changed:

from ..utils.env import DATA_HOME
from ..data import Vocab, JiebaTokenizer

aug_percent=None,
aug_min=1,
aug_max=10):
paddle.set_device("cpu")
Collaborator:

Make sure to restore the device state; the dataloader currently assumes a CPU environment by default.

Contributor Author:

Removed.
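Had the device pin been kept, a save-and-restore wrapper would avoid leaking state to the caller's dataloader; a generic sketch (the get/set functions here mock global device state and stand in for paddle.get_device / paddle.set_device):

```python
_state = {"device": "gpu:0"}  # mock global device state

def get_device():
    return _state["device"]

def set_device(name):
    _state["device"] = name

def run_on_cpu(fn):
    # Pin to CPU for the duration of fn, then restore the previous
    # device even if fn raises.
    prev = get_device()
    set_device("cpu")
    try:
        return fn()
    finally:
        set_device(prev)
```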



if __name__ == '__main__':
aug = WordDelete(create_n=10, aug_n=1)
Collaborator:

Same as Zeyang's comment above; consider moving this into a unit test.

Contributor Author:

Fixed.

Collaborator (@wawltor) left a comment:

LGTM

@wawltor wawltor merged commit 842954c into PaddlePaddle:develop Jul 28, 2022
@lugimzzz lugimzzz deleted the data_augmentation branch July 28, 2022 09:07
@lugimzzz lugimzzz restored the data_augmentation branch July 28, 2022 09:07
@lugimzzz lugimzzz deleted the data_augmentation branch September 19, 2022 04:52