Adding Conformer model #1327

Merged · 75 commits · Nov 3, 2020

Commits
543466b
Added initial code of Conformer.
VahidooX Oct 21, 2020
99e9233
Added initial code of Conformer.
VahidooX Oct 21, 2020
428a72d
Added log_every_n_steps.
VahidooX Oct 21, 2020
56770ad
dropped older multi_head_att modules.
VahidooX Oct 21, 2020
7213125
Dropped dropout_in and params.
VahidooX Oct 21, 2020
a4c4c79
Fixed code style.
VahidooX Oct 21, 2020
2c900aa
Updated docs.
VahidooX Oct 22, 2020
ac86cb3
Removed unused import.
VahidooX Oct 22, 2020
75aa1e4
fixed docs.
VahidooX Oct 22, 2020
25fff7f
Fixed license header.
VahidooX Oct 22, 2020
09d217c
Fixed license header.
VahidooX Oct 22, 2020
cda38ae
fixed style.
VahidooX Oct 22, 2020
c24b0f6
Updated tests with _target_.
VahidooX Oct 22, 2020
7a61905
Updated tests with _target_.
VahidooX Oct 22, 2020
9420578
Fixed style.
VahidooX Oct 22, 2020
ef7a236
Merge branch 'main' of https://github.com/NVIDIA/NeMo into add_confor…
VahidooX Oct 22, 2020
3402668
Fixed style.
VahidooX Oct 22, 2020
9c4993d
Merge branch 'main' of https://github.com/NVIDIA/NeMo into add_confor…
VahidooX Oct 22, 2020
fb324d0
Fixed missing params.
VahidooX Oct 22, 2020
93adc1a
Fixed missing params.
VahidooX Oct 23, 2020
2473812
Dropped u and v biases.
VahidooX Oct 23, 2020
1653599
Fixed params.
VahidooX Oct 23, 2020
81dff89
moved back padding in features.py.
VahidooX Oct 23, 2020
ea93557
fixed optimizer.
VahidooX Oct 26, 2020
f6787c0
Merge branch 'main' of https://github.com/NVIDIA/NeMo into add_confor…
VahidooX Oct 26, 2020
7c0fe36
Merge branch 'main' of https://github.com/NVIDIA/NeMo into add_confor…
VahidooX Oct 26, 2020
a7c7207
fixed vocab bug.
VahidooX Oct 26, 2020
3ceeebd
added load_weights_from_checkpoint.
VahidooX Oct 27, 2020
acfb8e6
Added jenkins test for Conformer. Updated names from bpe to subword.
VahidooX Oct 27, 2020
fd8764a
Added jenkins test for Conformer. Updated names from bpe to subword.
VahidooX Oct 27, 2020
5f241d4
enabled ddp.
VahidooX Oct 27, 2020
f0cae87
removed extra prints.
VahidooX Oct 28, 2020
b625c95
reorg the folders.
VahidooX Oct 29, 2020
09ee00f
reorg the folders.
VahidooX Oct 30, 2020
cdfd2ec
reorg the folders.
VahidooX Oct 30, 2020
b0f22cf
reorg the folders.
VahidooX Oct 30, 2020
175cd48
reverted back subword.
VahidooX Oct 30, 2020
0efb275
reverted back subword.
VahidooX Oct 30, 2020
fe40839
Merge branch 'main' of https://github.com/NVIDIA/NeMo into add_confor…
VahidooX Oct 30, 2020
c6b5b3c
updated test evaluation.
VahidooX Oct 30, 2020
e60bc15
Updated the code.
VahidooX Oct 30, 2020
8ea849b
fixed the style.
VahidooX Oct 30, 2020
be08878
Added docstring.
VahidooX Oct 30, 2020
d1ed640
dropped load_weights.
VahidooX Oct 30, 2020
c5922ac
fixed feat_out.
VahidooX Oct 30, 2020
58b85f0
fixed feat_out.
VahidooX Oct 30, 2020
42fea11
dropped the vocab files.
VahidooX Oct 30, 2020
c28b312
fixed the bug.
VahidooX Oct 30, 2020
4283e97
added logging of config.
VahidooX Oct 30, 2020
83ef7e1
moved swish to activations.
VahidooX Oct 30, 2020
2ccbb6a
fixed to_yaml.
VahidooX Oct 30, 2020
4d04686
fixed the feat_in bug.
VahidooX Oct 30, 2020
927f82b
Merge branch 'main' of https://github.com/NVIDIA/NeMo into add_confor…
VahidooX Oct 30, 2020
7d859fe
fixed the feat_in bug.
VahidooX Oct 31, 2020
a164748
fixed the feat_in bug.
VahidooX Oct 31, 2020
8c92c54
fixed the feat_in bug.
VahidooX Oct 31, 2020
d929c1d
fixed the feat_in bug.
VahidooX Oct 31, 2020
b4a28f9
Merge branch 'main' of https://github.com/NVIDIA/NeMo into add_confor…
VahidooX Oct 31, 2020
f336ed9
fixed the feat_in bug.
VahidooX Oct 31, 2020
dbfe71e
fixed the feat_in bug.
VahidooX Oct 31, 2020
5a0cf54
fixed the feat_in bug.
VahidooX Oct 31, 2020
6791136
fixed the feat_in bug.
VahidooX Oct 31, 2020
dfeafd1
fixed the feat_in bug.
VahidooX Nov 2, 2020
a74a327
fixed log_prediction.
VahidooX Nov 2, 2020
6600de9
fixed log bug.
VahidooX Nov 2, 2020
4b314db
fixed style.
VahidooX Nov 2, 2020
58f9588
added pos_emb_max_len.
VahidooX Nov 2, 2020
306ddf6
added pos_emb_max_len.
VahidooX Nov 2, 2020
7dabb78
moved subsampling.
VahidooX Nov 2, 2020
90cb95c
added open_dict()
VahidooX Nov 2, 2020
353e652
moved conformerblock.
VahidooX Nov 3, 2020
5e974b8
fixed import bug.
VahidooX Nov 3, 2020
3eef062
fixed code style.
VahidooX Nov 3, 2020
9c1a652
Merge branch 'main' of https://github.com/NVIDIA/NeMo into add_confor…
VahidooX Nov 3, 2020
302b9f3
added conformerencoder to all.
VahidooX Nov 3, 2020
19 changes: 17 additions & 2 deletions Jenkinsfile
@@ -209,10 +209,10 @@ pipeline {
       }
     }

-    stage('L2: Speech to Text WPE') {
+    stage('L2: Speech to Text WPE - CitriNet') {
       steps {
         sh 'python examples/asr/speech_to_text_bpe.py \
-        --config-path="experimental/configs/" --config-name="config_bpe" \
+        --config-path="experimental/citrinet/" --config-name="config_bpe" \
         model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \
         model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \
         model.tokenizer.dir="/home/TestData/asr_tokenizers/an4_wpe_128/" \
@@ -223,6 +223,21 @@ pipeline {
         sh 'rm -rf examples/asr/speech_to_text_wpe_results'
       }
     }
+
+    stage('L2: Speech to Text WPE - Conformer') {
+      steps {
+        sh 'python examples/asr/speech_to_text_bpe.py \
+        --config-path="experimental/conformer" --config-name="conformer_bpe" \
+        model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \
+        model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \
+        model.tokenizer.dir="/home/TestData/asr_tokenizers/an4_wpe_128/" \
+        model.tokenizer.type="wpe" \
+        trainer.gpus=[1] \
+        +trainer.fast_dev_run=True \
+        exp_manager.exp_dir=examples/asr/speech_to_text_wpe_conformer_results'
+        sh 'rm -rf examples/asr/speech_to_text_wpe_conformer_results'
+      }
+    }
   }
 }
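For context, the new Jenkins stage drives examples/asr/speech_to_text_bpe.py through Hydra command-line overrides. The sketch below is a minimal reconstruction of what such a NeMo entry point typically looks like, not the PR's actual script; the class and helper names follow NeMo's public API of this era, but treat the details as an assumption:

import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig

from nemo.collections.asr.models import EncDecCTCModelBPE
from nemo.utils.exp_manager import exp_manager


@hydra.main(config_path="experimental/conformer", config_name="conformer_bpe")
def main(cfg: DictConfig) -> None:
    # Overrides such as model.tokenizer.type="wpe" or +trainer.fast_dev_run=True
    # are merged into cfg by Hydra before this function runs.
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))  # logging/checkpointing
    model = EncDecCTCModelBPE(cfg=cfg.model, trainer=trainer)
    trainer.fit(model)


if __name__ == "__main__":
    main()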
162 changes: 162 additions & 0 deletions examples/asr/experimental/conformer/conformer_bpe.yaml
@@ -0,0 +1,162 @@
name: &name "Conformer-BPE"

model:
  sample_rate: &sample_rate 16000
  log_prediction: true
  load_weights_from_checkpoint: null
  ctc_reduction: 'mean_batch'

  train_ds:
    manifest_filepath: ???
    sample_rate: 16000
    batch_size: 16
    trim_silence: false
    max_duration: 16.7
    min_duration: 0.1
    shuffle: true
    is_tarred: false
    tarred_audio_filepaths: null
    num_workers: 4
    pin_memory: false
    use_start_end_token: true

  validation_ds:
    manifest_filepath: ???
    sample_rate: 16000
    batch_size: 16
    shuffle: false
    num_workers: 4
    pin_memory: false
    use_start_end_token: true

  test_ds:
    manifest_filepath: null
    sample_rate: 16000
    batch_size: 16
    shuffle: false
    num_workers: 4
    pin_memory: false
    use_start_end_token: true

  tokenizer:
    dir: ???  # path to a directory containing either tokenizer.model (for bpe) or vocab.txt (for wpe)
    type: ???  # can be either bpe or wpe

  preprocessor:
    _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
    sample_rate: *sample_rate
    normalize: "per_feature"
    window_size: 0.025
    window_stride: 0.01
    window: "hann"
    features: &n_mels 80
    n_fft: 512
    frame_splicing: 1
    dither: 0.00001
    pad_to: 16
    stft_conv: false

  spec_augment:
    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
    # SpecAug params
    freq_masks: 2  # set to zero to disable the SpecAug augmentation
    time_masks: 2  # set to zero to disable the SpecAug augmentation
    freq_width: 27
    time_width: 100
    # Cut-off params
    rect_masks: 0  # set to zero to disable the cut-off augmentation
    rect_time: 120
    rect_freq: 50

  encoder:
    _target_: nemo.collections.asr.modules.ConformerEncoder
    feat_in: *n_mels
    feat_out: -1  # set it if you need an output size different from the default d_model
    n_layers: 16
    d_model: 256

    # Sub-sampling params
    subsampling: vggnet  # vggnet or striding
    subsampling_factor: 4  # must be a power of 2
    subsampling_conv_channels: 64  # set to -1 to make it equal to d_model

    # Feed-forward module's params
    ff_expansion_factor: 4

    # Multi-headed attention module's params
    self_attention_model: rel_pos  # rel_pos or abs_pos
    n_heads: 4
    xscaling: true

    # Convolution module's params
    conv_kernel_size: 31

    ### Regularization
    dropout: 0.1  # the dropout used inside the Conformer modules
    dropout_emb: 0.1  # the dropout used for the embeddings
    dropout_att: 0.0  # the dropout for the multi-headed attention modules

  decoder:
    _target_: nemo.collections.asr.modules.LSTMDecoder
    feat_in: null  # if not provided, the encoder's feat_out is used
    num_classes: -1  # filled with the vocabulary size from the tokenizer at runtime
    vocabulary: []  # filled with the vocabulary from the tokenizer at runtime
    lstm_hidden_size: 640
    bidirectional: false
    num_layers: 1

  optim:
    name: novograd
    lr: 0.01
    # optimizer arguments
    betas: [0.8, 0.5]
    weight_decay: 0.001

    # scheduler setup
    sched:
      name: CosineAnnealing
      # scheduler config override
      warmup_steps: 4000
      warmup_ratio: null
      min_lr: 1e-9
      last_epoch: -1

trainer:
  gpus: 0  # number of gpus
  num_nodes: 1
  max_epochs: 100
  max_steps: null  # computed at runtime if not set
  val_check_interval: 1  # set to 0.25 to check 4 times per epoch, or to an int for a number of iterations
  distributed_backend: ddp
  accumulate_grad_batches: 2
  gradient_clip_val: 0.0
  amp_level: O0  # O1/O2 for mixed precision
  precision: 32  # should be set to 16 for O1 and O2 to enable AMP
  log_every_n_steps: 10  # interval of logging
  resume_from_checkpoint: null  # path to a checkpoint file to resume training from; restores the whole state, including the epoch, step, LR schedulers, apex, etc.
  num_sanity_val_steps: 0  # number of validation steps run as a sanity check before training starts; 0 disables it
  check_val_every_n_epoch: 10  # run validation every n epochs
  sync_batchnorm: true
  checkpoint_callback: false  # provided by exp_manager
  logger: false  # provided by exp_manager

exp_manager:
  exp_dir: null
  name: *name
  create_tensorboard_logger: true
  create_checkpoint_callback: true
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: null
    project: null
  resume_if_exists: false
  resume_ignore_no_checkpoint: false

hydra:
  run:
    dir: .
  job_logging:
    root:
      handlers: null
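Each module section in this config carries a _target_ key, which Hydra resolves to a class and instantiates with the section's remaining keys as constructor arguments. A small illustrative sketch of that mechanism (the file path is an assumption, and any field left as ??? must be overridden before it is accessed):

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/asr/experimental/conformer/conformer_bpe.yaml")

# The *n_mels anchor makes encoder.feat_in resolve to the preprocessor's 80 mel bins.
assert cfg.model.encoder.feat_in == cfg.model.preprocessor.features == 80

# instantiate() builds nemo.collections.asr.modules.ConformerEncoder from _target_,
# passing n_layers=16, d_model=256, and the rest of the section as kwargs.
encoder = instantiate(cfg.model.encoder)
print(type(encoder).__name__)  # ConformerEncoder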
159 changes: 159 additions & 0 deletions examples/asr/experimental/conformer/conformer_char.yaml
@@ -0,0 +1,159 @@
name: &name "Conformer-char"

model:
  sample_rate: &sample_rate 16000
  labels: &labels [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
                   "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"]
  log_prediction: true
  load_weights_from_checkpoint: null
  ctc_reduction: 'mean_batch'

  train_ds:
    manifest_filepath: ???
    labels: *labels
    sample_rate: 16000
    batch_size: 16
    trim_silence: false
    max_duration: 16.7
    min_duration: 0.1
    shuffle: true
    is_tarred: false
    tarred_audio_filepaths: null
    num_workers: 4
    pin_memory: true

  validation_ds:
    manifest_filepath: ???
    labels: *labels
    sample_rate: 16000
    batch_size: 16
    shuffle: false
    num_workers: 4
    pin_memory: true

  test_ds:
    manifest_filepath: null
    labels: *labels
    sample_rate: 16000
    batch_size: 16
    shuffle: false
    num_workers: 4
    pin_memory: true

  preprocessor:
    _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
    sample_rate: *sample_rate
    normalize: "per_feature"
    window_size: 0.025
    window_stride: 0.01
    window: "hann"
    features: &n_mels 80
    n_fft: 512
    frame_splicing: 1
    dither: 0.00001
    pad_to: 16
    stft_conv: false

  spec_augment:
    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
    # SpecAug params
    freq_masks: 2  # set to zero to disable the SpecAug augmentation
    time_masks: 2  # set to zero to disable the SpecAug augmentation
    freq_width: 27
    time_width: 100
    # Cut-off params
    rect_masks: 0  # set to zero to disable the cut-off augmentation
    rect_time: 120
    rect_freq: 50

  encoder:
    _target_: nemo.collections.asr.modules.ConformerEncoder
    feat_in: *n_mels
    feat_out: -1  # set it if you need an output size different from the default d_model
    n_layers: 16
    d_model: 256

    # Sub-sampling params
    subsampling: vggnet  # vggnet or striding
    subsampling_factor: 4  # must be a power of 2
    subsampling_conv_channels: 64  # set to -1 to make it equal to d_model

    # Feed-forward module's params
    ff_expansion_factor: 4

    # Multi-headed attention module's params
    self_attention_model: rel_pos  # rel_pos or abs_pos
    n_heads: 4
    xscaling: true

    # Convolution module's params
    conv_kernel_size: 31

    ### Regularization
    dropout: 0.1  # the dropout used inside the Conformer modules
    dropout_emb: 0.1  # the dropout used for the embeddings
    dropout_att: 0.0  # the dropout for the multi-headed attention modules

  decoder:
    _target_: nemo.collections.asr.modules.LSTMDecoder
    feat_in: null  # if not provided, the encoder's feat_out is used
    num_classes: 28
    vocabulary: *labels
    lstm_hidden_size: 640
    bidirectional: false
    num_layers: 1

  optim:
    name: novograd
    lr: 0.01
    # optimizer arguments
    betas: [0.8, 0.5]
    weight_decay: 0.001

    # scheduler setup
    sched:
      name: CosineAnnealing
      # scheduler config override
      warmup_steps: 1000
      warmup_ratio: null
      min_lr: 1e-9
      last_epoch: -1

trainer:
  gpus: 0  # number of gpus
  num_nodes: 1
  max_epochs: 100
  max_steps: null  # computed at runtime if not set
  val_check_interval: 1  # set to 0.25 to check 4 times per epoch, or to an int for a number of iterations
  distributed_backend: ddp
  accumulate_grad_batches: 2
  gradient_clip_val: 0.0
  amp_level: O0  # O1/O2 for mixed precision
  precision: 32  # should be set to 16 for O1 and O2 to enable AMP
  log_every_n_steps: 10  # interval of logging
  resume_from_checkpoint: null  # path to a checkpoint file to resume training from; restores the whole state, including the epoch, step, LR schedulers, apex, etc.
  num_sanity_val_steps: 0  # number of validation steps run as a sanity check before training starts; 0 disables it
  check_val_every_n_epoch: 10  # run validation every n epochs
  sync_batchnorm: true
  checkpoint_callback: false  # provided by exp_manager
  logger: false  # provided by exp_manager

exp_manager:
  exp_dir: null
  name: *name
  create_tensorboard_logger: true
  create_checkpoint_callback: true
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: null
    project: null
  resume_if_exists: false
  resume_ignore_no_checkpoint: false

hydra:
  run:
    dir: .
  job_logging:
    root:
      handlers: null
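One detail worth noting in the character config: the decoder's num_classes (28) must equal the size of the labels list (26 letters plus space and apostrophe); the CTC blank symbol is added internally by the loss and is not listed among the labels. A tiny illustrative consistency check (the file path is an assumption):

from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/asr/experimental/conformer/conformer_char.yaml")
labels = cfg.model.labels

# 26 letters + space + apostrophe = 28 output classes; CTC adds its own blank.
assert len(labels) == cfg.model.decoder.num_classes == 28
assert cfg.model.decoder.vocabulary == labels  # the *labels anchor keeps them in sync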