-
Notifications
You must be signed in to change notification settings - Fork 1.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ASR] Support Hubert, fintuned on the librispeech dataset #3088
Conversation
Thanks for your contribution! |
This pull request is now in conflict :( |
@@ -55,6 +58,8 @@ def __init__(self, config: dict): | |||
reduction='mean') | |||
|
|||
def forward(self, wav, wavs_lens_rate, target, target_lens): | |||
# import pdb | |||
# pdb.set_trace() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注释可以删一下
examples/librispeech/asr3/run.sh
Outdated
@@ -19,6 +20,7 @@ audio_file=data/demo_002_en.wav | |||
|
|||
avg_ckpt=avg_${avg_num} | |||
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') | |||
ckpt=test6 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删一下
dataset/librispeech/librispeech.py
Outdated
@@ -133,7 +133,7 @@ def create_manifest(data_dir, manifest_path): | |||
def prepare_dataset(url, md5sum, target_dir, manifest_path): | |||
"""Download, unpack and create summmary manifest file. | |||
""" | |||
if not os.path.exists(os.path.join(target_dir, "LibriSpeech")): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为什么要变?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
原先的代码似乎和librispeech解压出的结果不太一致,本地已有librispeech数据集的情况下不太方便
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
什么意思?这里不是有的话就不下载了吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok,按照之前的吧
|
||
|
||
task_cfg: | ||
sample_rate: 16000 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议能否添加pretrain/finetune的标签
# Data Augmentation # | ||
############################################ | ||
audio_augment: # for raw audio | ||
sample_rate: 16000 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么需要两个sample_rate参数
self.mask_emb = paddle.create_parameter( | ||
shape=[cfg.encoder_embed_dim], | ||
dtype='float32', | ||
default_initializer=paddle.nn.initializer.Uniform(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch和paddle对于uniform的初始化范围不一致,torch为(0,1),paddle为(-1,1),可以确定下是否会对训练产生影响,或者直接加上low和high参数
self.label_embs_concat = paddle.create_parameter( | ||
shape=[sum(self.num_classes), final_dim], | ||
dtype='float32', | ||
default_initializer=paddle.nn.initializer.Uniform(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
|
||
return x, mask_indices | ||
|
||
def compute_nce(x, pos, negs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self?
from dataclasses import dataclass, field, is_dataclass | ||
from copy import deepcopy | ||
|
||
from omegaconf import II, MISSING, open_dict |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
有用到吗?
|
||
|
||
class HubertBase(nn.Layer): | ||
"""Wav2vec2 model""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改一改
enc_n_units: 1024 | ||
blank_id: 0 | ||
dropout_rate: 0.0 | ||
hubert_params_path: "exp/hubert/pd_hubert.pdparams" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个模型是否可以给出下载链接?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
细节比较多,先review下,后面再细看。
fp16: True | ||
label_rate: 50 | ||
extractor_mode: layer_norm | ||
encoder_layers: 24 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这是Large的配置?配置文件区分下吧
examples/librispeech/asr3/path.sh
Outdated
@@ -10,6 +10,5 @@ export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} | |||
|
|||
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ | |||
|
|||
|
|||
MODEL=wav2vec2 | |||
MODEL=$1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个不需要固定,不能用传参的方式。如果是和wav2vec一个asr目录的话就单开个吧。
examples/librispeech/asr3/run.sh
Outdated
stage=0 | ||
stop_stage=0 | ||
conf_path=conf/wav2vec2ASR.yaml | ||
gpus=2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
记得够改回默认值。
logger = Log(__name__).getlog() | ||
|
||
|
||
def clip_grad_norm_( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
替换成paddle的API吧。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我看wav2vec2目前也用的这个接口?paddle的对应api是哪个?可以用了吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dev和最近的2.5有这个API了。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里我注释了todo,后面paddle依赖改为2.5后再改这里吧
self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate | ||
|
||
self.post_extract_proj = ( | ||
nn.Linear(self.embed, cfg.encoder_embed_dim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要替换成align.Linaer,相关的都需要操作下。
self.target_glu = None | ||
if cfg.target_glu: | ||
self.target_glu = nn.Sequential( | ||
nn.Linear(final_dim, final_dim * 2), GLU()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
self.target_glu = nn.Sequential( | ||
nn.Linear(final_dim, final_dim * 2), GLU()) | ||
|
||
self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
is_group_norm=False, | ||
conv_bias=False, ): | ||
def make_conv(): | ||
conv = nn.Conv1D( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
替换成align
def make_conv_block(e, k, g, l): | ||
return nn.Sequential(*[ | ||
nn.Sequential( | ||
nn.Conv1D( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
替换。
|
||
# layer norm associated with the self attention layer | ||
self.self_attn_layer_norm = LayerNorm(self.embedding_dim) | ||
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
@@ -0,0 +1,586 @@ | |||
# Copyright (c) Facebook, Inc. and its affiliates. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
少个__init__.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
Models
Describe
support ASR Hubert