-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add data augmentation strategy #2805
Conversation
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .base_augment import * |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
整体缺少一个使用的说明文档,可以在整体工作完备之后出一个文档
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,开发完成后会补充文档和使用教程
("word_homonym.json", "a578c04201a697e738f6a1ad555787d5", | ||
"https://bj.bcebos.com/paddlenlp/data/word_homonym.json") | ||
} | ||
self.stop_words = self._get_data('stop_words') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些变量定义和函数可以统一标准,如果是内置变量,不希望被访问,可以 _ 开头变成一个半私有变量
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
将不希望访问的变量改为单下划线开头,如self.DATA->self._DATA
'''Calculate number of words for data augmentation''' | ||
if size == 0: | ||
return 0 | ||
aug_percent = self.aug_percent or 0.02 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这块是不是在类初始化的时候aug_percent默认就是0.02
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改默认初始化0.02
aug_percent=None, | ||
aug_min=1, | ||
aug_max=10): | ||
paddle.set_device("cpu") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里没有太明白为什么paddle.set_device('cpu')
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已删除
'''Read data as list ''' | ||
fullname = self._load_file(mode) | ||
data = [] | ||
with open(fullname, 'r', encoding='utf-8') as f: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里要不要验证一下这个file是否存在
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已添加if os.path.exists(fullname):
|
||
def _generate_random_index(self, seq_tokens, skip=True): | ||
'''Random sample words for insertion/deletion/swap''' | ||
# skip stopping words |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
skip -> Skip
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
aug_n = min(aug_n, len(indexes)) | ||
return random.sample(indexes, aug_n) | ||
|
||
def augment(self, sequences, num_thread=1): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个函数是一个对外public函数,需要把函数和参数写清楚
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已加上
for sequence in sequences: | ||
output.append(self._augment(sequence)) | ||
return output | ||
# TO BE DONE: Multi Thread |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
多进程确实要考虑一下,目前windows的多进程需要多测试一下,有些坑 https://segmentfault.com/a/1190000013681586
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to be done
# limitations under the License. | ||
import random | ||
|
||
from paddlenlp.data_augmentation import BaseAugment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果是library里面,建议import关系,直接import 类似这种 ..data_augmentation import BaseAugment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改为from ..data_augmentation import BaseAugment
return indexes | ||
|
||
|
||
if __name__ == '__main__': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
library里面的函数不建议加main函数
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已删除
seq_tokens = self.tokenizer.cut(sequence) | ||
aug_n = self._get_aug_n(len(seq_tokens)) | ||
aug_indexes = self.skip_words(seq_tokens) | ||
aug_n = min(aug_n, len(aug_indexes)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这块需要讨论一下,如果字符串skip words之后,剩下的字符过少时,是不是就不用skip了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加入策略,被增强的词数量aug_n不得大于len(aug_indexes)*0.3,也即至少每四个词才能有一个词使用数据增强策略
def _get_aug_n(self, size, size_a=None):
if size == 0:
return 0
aug_n = self.aug_n or int(math.ceil(self.aug_percent * size))
if self.aug_min and aug_n < self.aug_min:
aug_n = self.aug_min
elif self.aug_max and aug_n > self.aug_max:
aug_n = self.aug_max
if size_a is not None:
aug_n = min(aug_n, int(math.floor(size_a*0.3)))
return aug_n
fullname = self.custom_file_path | ||
elif source_type in ['delete']: | ||
fullname = self.delete_file_path | ||
with open(fullname, 'r', encoding='utf-8') as f: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
判断一下文件是否存在吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已加if os.path.exists(fullname):
|
||
|
||
if __name__ == '__main__': | ||
aug = WordInsert(aug_type='synonym', create_n=2, aug_n=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
main函数去掉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
去掉了
idxes = random.sample(list(range(len(candidate_tokens))), | ||
aug_n) | ||
aug_tokens = [] | ||
for idx in idxes: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这块不整体sample一次,单次一次次sample,random.sample这块的耗时可能是一个性能瓶颈
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
random对速度的影响,开发完成后进行数据测评再选择优化方案
aug_indexes = random.sample(aug_indexes, aug_n) | ||
for aug_index in aug_indexes: | ||
token = self.vocab.to_tokens( | ||
random.randint(0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
self._reverse_sequence(seq_tokens.copy(), [aug_token])) | ||
return sentences | ||
|
||
def _reverse_sequence(self, output_seq_tokens, aug_tokens): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个函数名字为啥叫reverse了,好像和反转没有关系
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
函数改名为_generate_sequence
aug = WordSwap(create_n=2, aug_n=1) | ||
s1 = '2021年,我再看深度学习领域,无论是自然语言处理、音频信号处理、图像处理、推荐系统,似乎都看到attention混得风生水起,只不过更多时候看到的是它的另一个代号:Transformer。' | ||
|
||
augmented = aug.augment(s1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已删除
if source_type in ['synonym', 'homonym']: | ||
fullname = self._load_file('word_' + source_type) | ||
elif source_type in ['custom']: | ||
fullname = self.custom_file_path |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已加if os.path.exists(fullname):
elif source_type in ['delete']: | ||
fullname = self.delete_file_path | ||
with open(fullname, 'r', encoding='utf-8') as f: | ||
substitue_dict = json.load(f) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
判断一下file是否存在
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已加if os.path.exists(fullname):
aug_indexes = random.sample(aug_indexes, aug_n) | ||
for aug_index in aug_indexes: | ||
token = self.vocab.to_tokens( | ||
random.randint(0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
import paddle | ||
from paddle.dataset.common import md5file | ||
from paddle.utils.download import get_path_from_url | ||
from paddlenlp.utils.env import DATA_HOME |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里paddlenlp的内容,看要不要用相对路径导入
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
from ..utils.env import DATA_HOME
from ..data import Vocab, JiebaTokenizer
aug_percent=None, | ||
aug_min=1, | ||
aug_max=10): | ||
paddle.set_device("cpu") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注意状态恢复,目前dataloader中,是默认 cpu 环境
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已删除
|
||
|
||
if __name__ == '__main__': | ||
aug = WordDelete(create_n=10, aug_n=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同泽阳,可以考虑放单测中
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
APIs
Description
新增基于词表的词级别替换(基于同义词、同音词、随机词、本地词表、组合词表)、删除(随机)、插入(基于同义词、同音词、本地词表、随机词、组合词表)、交换(随机)的数据增强策略。
模型功能: