Adding WavLM implementation #3242

Merged (11 commits) on Jun 1, 2023

@jiamingkong (Contributor) commented May 15, 2023

PR types

New features

PR changes

Models

Describe

This PR implements the WavLM model described in https://arxiv.org/abs/2110.13900 for speech recognition. On the LibriSpeech clean set, the model fine-tuned from wavlm-base-plus reaches a WER of 6.0%.

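Reproduction results:

On the LibriSpeech Clean 100-hour dataset, without a language model:

| Model | Paper (WER) | Paddle reproduction (WER) | Torch reproduction (WER) |
| --- | --- | --- | --- |
| wavlm-base | 5.7% | 5.8% | 6.8% |
| wavlm-base-plus | 4.7% | 5.6% | - |

The Paddle reproduction of wavlm-base comes close to the paper's result and outperforms the Torch implementation.

Torch reproduction link: https://huggingface.co/patrickvonplaten/wavlm-libri-clean-100h-base-plus

Alignment results

  1. When trained with the hyperparameters of the Torch reproduction above, the Paddle version reaches about 6.9%, consistent with Torch.
  2. To reach the results reported in the paper, we changed the model in the Paddle reproduction:
    • Architecture: the WavLM encoder is followed by a three-layer MLP + BatchNorm + activation function (the same as the wav2vec2ASR implementation in paddlespeech).
    • Training: the WavLM weights use an optimizer with warmup and a smaller learning rate, while the three-layer MLP uses an optimizer with a normal learning rate.

These two changes make it possible to train a model whose accuracy is close to that described in the paper.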
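The two-optimizer training setup described above can be pictured with a small learning-rate sketch. This is only an illustration, not the code from this PR: the function names `encoder_lr` and `head_lr`, the linear warmup shape, and all numeric values (`peak_lr`, `warmup_steps`, `lr`) are assumptions chosen for the example.

```python
def encoder_lr(step, peak_lr=1e-5, warmup_steps=1000):
    """Hypothetical schedule for the WavLM weights: linear warmup, then constant.

    peak_lr and warmup_steps are illustrative values, not taken from the PR.
    """
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    return peak_lr


def head_lr(step, lr=1e-3):
    """Hypothetical schedule for the three-layer MLP: a plain, larger constant rate."""
    return lr


# Compare the two rates at a few training steps.
for step in (0, 499, 999, 5000):
    print(step, encoder_lr(step), head_lr(step))
```

The point of the split is that the pre-trained encoder is updated gently (small rate, ramped up slowly), while the freshly initialized MLP can learn at an ordinary rate from the first step.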

@paddle-bot (bot) commented May 15, 2023

Thanks for your contribution!

@zxcd (Collaborator) commented May 16, 2023

Please delete the empty file conv_layers.py.

Stage 0 also downloads the pre-trained [hubert](https://paddlespeech.bj.bcebos.com/hubert/hubert-large-lv60.pdparams) model.
```bash
mkdir -p exp/hubert
wget -P exp/hubert https://paddlespeech.bj.bcebos.com/hubert/hubert-large-lv60.pdparams
```

Collaborator: Please change the model link.

@@ -0,0 +1,18 @@
#!/usr/bin/env python3
Collaborator: This file can be deleted; the paths should be configured in path.sh instead.

@@ -0,0 +1,558 @@
# Copyright 2021 Mobvoi Inc. All Rights Reserved.
Collaborator: Delete this file and add a path.sh file instead.

@@ -0,0 +1,143 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Collaborator: Same as above.

examples/librispeech/asr5/test.profile (outdated, resolved)
# from paddlespeech.utils.argparse import print_arguments
import distutils.util

def add_arguments(argname, type, default, help, argparser, **kwargs):
Collaborator: The existing function can be reused here.

help=help + ' Default: %(default)s.',
**kwargs)

def print_arguments(args, info=None):
Collaborator: Same as above.


import distutils.util

def add_arguments(argname, type, default, help, argparser, **kwargs):
Collaborator: Same as above.

help=help + ' Default: %(default)s.',
**kwargs)

def print_arguments(args, info=None):
Collaborator: Same as above.

}


def get_activation(activation_string):
Collaborator: Can the one in the modules directory be reused?

Collaborator: Empty file.

@jiamingkong (Contributor, Author)

Got it, I will fix everything mentioned above.

In the meantime, here are the weights:

Link: https://pan.baidu.com/s/1Yjv1rITAWeYv-MjRJD-PPg?pwd=wavl
Extraction code: wavl

@zh794390558 (Collaborator)

examples/librispeech/asr5/format_rsl.py can be deleted; it already exists in the utils directory.

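@jiamingkong (Contributor, Author) commented May 25, 2023

Reproduction results:

On the LibriSpeech Clean 100-hour dataset, without a language model:

| Model | Paper (WER) | Paddle reproduction (WER) | Torch reproduction (WER) |
| --- | --- | --- | --- |
| wavlm-base | 5.7% | 5.8% | 6.8% |
| wavlm-base-plus | 4.7% | 5.6% | - |

The Paddle reproduction of wavlm-base comes close to the paper's result and outperforms the Torch implementation.

Torch reproduction link: https://huggingface.co/patrickvonplaten/wavlm-libri-clean-100h-base-plus

Alignment results

  1. When trained with the hyperparameters of the Torch reproduction above, the Paddle version reaches about 6.9%, consistent with Torch.
  2. To reach the results reported in the paper, we changed the model in the Paddle reproduction:
    • Architecture: the WavLM encoder is followed by a three-layer MLP + BatchNorm + activation function (the same as the wav2vec2ASR implementation in paddlespeech).
    • Training: the WavLM weights use an optimizer with warmup and a smaller learning rate, while the three-layer MLP uses an optimizer with a normal learning rate.

These two changes make it possible to train a model whose accuracy is close to that described in the paper.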
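As a rough picture of the architectural change, the head placed after the encoder could look like the NumPy sketch below. It is not the PR's code: the layer sizes, the choice of LeakyReLU as the activation, applying BatchNorm after every linear layer, and the names `mlp_head`, `batch_norm`, and `leaky_relu` are all assumptions made for illustration, and a random array stands in for the WavLM encoder output.

```python
import numpy as np


def leaky_relu(x, slope=0.01):
    # Assumed activation; the PR text only says "activation function".
    return np.where(x > 0, x, slope * x)


def batch_norm(x, eps=1e-5):
    # Training-mode normalization: per-feature statistics over (batch, time).
    mean = x.mean(axis=(0, 1), keepdims=True)
    var = x.var(axis=(0, 1), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def mlp_head(features, weights, biases):
    """Apply Linear -> BatchNorm -> activation once per (weight, bias) pair."""
    h = features
    for w, b in zip(weights, biases):
        h = leaky_relu(batch_norm(h @ w + b))
    return h


rng = np.random.default_rng(0)
enc_dim, hidden = 768, 1024  # hypothetical sizes
dims = [enc_dim, hidden, hidden, hidden]  # three linear layers
weights = [rng.standard_normal((i, o)) * 0.02 for i, o in zip(dims[:-1], dims[1:])]
biases = [np.zeros(o) for o in dims[1:]]

encoded = rng.standard_normal((2, 50, enc_dim))  # stand-in for encoder output [batch, time, dim]
out = mlp_head(encoded, weights, biases)
print(out.shape)  # (2, 50, 1024)
```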

paddlespeech/s2t/models/wavlm/modules/functional.py (outdated, resolved)
paddlespeech/s2t/models/wavlm/modules/functional.py (outdated, resolved)
return out


def addr(input, vec1, vec2, beta=1, alpha=1, out=None):
Collaborator: What does this compute? Please add a doc-string.

Contributor, Author: Done. This is a helper function that computes alpha * (vec1 * vec2.T) + beta * input, used for the QK computation in attention.
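To make the formula in the reply concrete, here is a minimal NumPy sketch. It assumes that `vec1 * vec2.T` means the outer product of two 1-D vectors; the name `addr_sketch` is hypothetical and this is not the function from the PR.

```python
import numpy as np


def addr_sketch(input, vec1, vec2, beta=1, alpha=1):
    """Return beta * input + alpha * outer(vec1, vec2).

    Assumes vec1 has shape [n], vec2 has shape [m], and input is
    broadcastable to [n, m].
    """
    input = np.asarray(input, dtype=float)
    return beta * input + alpha * np.outer(vec1, vec2)


bias = np.ones((2, 3))
result = addr_sketch(bias, [1.0, 2.0], [1.0, 2.0, 3.0], beta=0.5, alpha=2)
print(result)
# [[ 2.5  4.5  6.5]
#  [ 4.5  8.5 12.5]]
```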

paddlespeech/s2t/models/wavlm/modules/modules.py (outdated, resolved)
# normal_(module.v_proj.weight.data)


def quant_noise(module, p, block_size):
Collaborator: Is this used?

Contributor, Author: Yes. This function is needed if WavLM is to be pre-trained from scratch.

paddlespeech/s2t/models/wavlm/wavlm_paddle.py (resolved)
@zh794390558 (Collaborator) previously approved these changes May 31, 2023

LGTM

@@ -0,0 +1,9 @@
{
"do_normalize": true,
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
Collaborator: Is this file actually used?

Contributor, Author: It should be. The feature extractor takes the audio and subtracts its mean, producing a zero-mean tensor of shape [time * 16000, 1].
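The preprocessing described in the reply can be sketched in a few lines of NumPy. This follows only what the reply states (mean subtraction and a [time * 16000, 1] shape); the function name `zero_mean`, the 2-second clip, and the random stand-in audio are assumptions, and the real feature extractor may do more than this.

```python
import numpy as np


def zero_mean(waveform):
    """Subtract the mean and return a column tensor of shape [num_samples, 1]."""
    waveform = np.asarray(waveform, dtype=np.float32)
    centered = waveform - waveform.mean()
    return centered.reshape(-1, 1)


sample_rate = 16000
seconds = 2  # hypothetical clip length
# Random samples with a DC offset stand in for real audio.
audio = np.random.default_rng(0).standard_normal(seconds * sample_rate) + 0.5

features = zero_mean(audio)
print(features.shape)  # (32000, 1)
```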

logger = Log(__name__).getlog()


class Wav2vec2Infer():
Collaborator: wav2vec2?

Contributor, Author: Fixed.

. ./path.sh || exit 1;
. ./cmd.sh || exit 1;

gpus=1,2,3
Collaborator: Preferably start from 0.

Contributor, Author: Fixed.

@zh794390558 (Collaborator) left a comment

LGTM

@zh794390558 (Collaborator) previously approved these changes May 31, 2023

LGTM

@zxcd (Collaborator) left a comment

LGTM

@zh794390558 (Collaborator) left a comment

LGTM

@zh794390558 zh794390558 merged commit 2214c0d into PaddlePaddle:develop Jun 1, 2023
5 of 7 checks passed