CTC-only training recipes for LibriSpeech (code from Samsung AI Cambridge) (speechbrain#2290)

CTC-only pre-training of conformer and branchformer.

---------

Co-authored-by: Shucong Zhang/Embedded AI /SRUK/Engineer/Samsung Electronics <s1.zhang@sruk-ccn4.eu.corp.samsungelectronics.net>
Co-authored-by: Adel Moumen <adelmoumen.pro@gmail.com>
Co-authored-by: Adel Moumen <88119391+Adel-Moumen@users.noreply.github.com>
Co-authored-by: Parcollet Titouan <titouan.parcollet@univ-avignon.fr>
5 people committed Apr 18, 2024
1 parent 608d1d5 commit d086cde
Showing 9 changed files with 1,056 additions and 7 deletions.
23 changes: 20 additions & 3 deletions recipes/LibriSpeech/ASR/CTC/README.md
@@ -1,4 +1,4 @@
# LibriSpeech ASR with CTC and pre-trained wav2vec2 or whisper models.
# LibriSpeech ASR with CTC only or pre-trained wav2vec2 or whisper models.
This folder contains the scripts to fine-tune a wav2vec2 or a whisper-based system using LibriSpeech.
You can download LibriSpeech at http://www.openslr.org/12.
The loss function is the CTC loss, implemented in two different ways:
@@ -17,17 +17,31 @@ pip install -r extra_requirements.txt

# How to run
```
python train.py hparams/file.yaml
python train_with_wav2vec.py hparams/file.yaml
```
```
python train_with_whisper.py hparams/file.yaml
```
To fine-tune "WavLM" with downsampled input signals (for faster training and inference), run:

```
python train_with_wav2vec.py hparams/downsampled/train_hf_wavlm_signal_downsampling.yaml --downsampling_factor 2
```
To train a model from scratch (without any pre-training), first go to the Tokenizer folder and train a tokenizer:

```
cd ../../Tokenizer
python train.py hparams/128_bpe.yaml
```
Then, go back to this directory. You can train a Branchformer CTC model with:

```
python train.py hparams/branchformer_large.yaml
```
or a Conformer CTC model with:

```
python train.py hparams/conformer_large.yaml
```
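The from-scratch configurations are sized for multi-GPU training (the reported Branchformer/Conformer results below were obtained on 8xA40 GPUs; see the batch-size comments in the yaml files). With a recent SpeechBrain version, a multi-GPU DDP launch typically looks like the following sketch (assuming a single node with 8 GPUs; adjust `--nproc_per_node` to your hardware):

```
torchrun --nproc_per_node=8 train.py hparams/branchformer_large.yaml
```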
# WFST-based CTC loss
To fine-tune a wav2vec 2.0 model with the WFST-based CTC loss, you can use the `train_with_wav2vec_k2.py` script. This will create a `lang` directory inside your output folder, which will contain the files required to build a lexicon FST. The tokenization method used here is a very basic character-based tokenization (e.g. `hello -> h e l l o`).
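For illustration, the lexicon built from this tokenization simply spells each word out character by character; a hypothetical excerpt (not taken from the actual generated files) would look like:

```
hello h e l l o
world w o r l d
```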

@@ -84,6 +98,9 @@ Note: by default, `topk` is set to 20 as it gives a good trade-off between WER a
| 23-01-24 | train_hf_wav2vec_k2.yaml | k2CTC + HLG graph + whole lattice rescoring + test batch size = 1 | 960h | 1.81 | Not Avail. | 3.57 | Not Avail. | Not Avail. | [Link](https://www.dropbox.com/scl/fo/kj2ujqj3votq7ue6ydh0l/h?rlkey=mibyoria19zasvuxs0iwx6plt&dl=0) | 1xRTX2080Ti 12GB | 1xRTX2080Ti 12GB |
| 08-12-23 | train_hf_wav2vec.yaml | CTCBeamSearch + RNNLM Rescorer + test batch size = 1 + topk = 100 | 960h | 1.69 | 26mins15 | 3.55 | 32min44s | Not Avail. | [Link](https://www.dropbox.com/sh/k4ixa211yp5b1tm/AAD85sgYw2CH7NKk_qKMO9Tja?dl=0) | 1x A100 40GB | 2xTesla V100 40GB |
| 08-12-23 | train_hf_wav2vec.yaml | CTCBeamSearch + TransformerLM Rescorer + test batch size = 1 + topk = 100 | 960h | 1.57 | 26mins56s | 3.37 | 32min46 | Not Avail. | [Link](https://www.dropbox.com/sh/ijqalvre7mm08ng/AAD_hsN-8dBneUMMkELsOOxga?dl=0) | 1x A100 40GB | 2xTesla V100 32GB |
| 06-12-23 | branchformer_large.yaml (25.9M) | 960h | 3.6 (no LM) | Not Avail. | Not Avail. | 8xA40 46G |
| 06-12-23 | conformer_large.yaml (28.8M) | 960h | 3.7 (no LM) | Not Avail. | Not Avail. | 8xA40 46G |


# Downsampling inputs for faster fine-tuning and inference using SSL models
This repository contains the code needed to reproduce part of the results obtained in the paper: "Fine-tuning Strategies for Faster Inference using Speech Self-Supervised Models: A Comparative Study"
252 changes: 252 additions & 0 deletions recipes/LibriSpeech/ASR/CTC/hparams/branchformer_large.yaml
@@ -0,0 +1,252 @@
# ############################################################################
# Model: E2E ASR with CTC
# Encoder: Branchformer Encoder
# Decoder: CTC beam searcher and greedy searcher
# Tokens: character
# Training: Librispeech 960h
# Authors: Titouan Parcollet, Shucong Zhang, Adel Moumen
# ############################################################################
# Seed needs to be set at top of yaml, before objects with parameters are made

seed: 3402
__set_seed: !apply:torch.manual_seed [!ref <seed>]
output_folder: !ref results/branchformer_ctc/
wer_file: !ref <output_folder>/wer.txt
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# Data files
data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech
# If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES
# then data_folder_rirs should be /localscratch/xxx_corpus
# otherwise the dataset will automatically be downloaded
# data_folder_rirs: !ref <data_folder>
train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
dev_splits: ["dev-clean"]
test_splits: ["dev-clean", "test-clean", "test-other"]
skip_prep: False
train_csv: !ref <output_folder>/train.csv
valid_csv: !ref <output_folder>/dev-clean.csv
test_csv:
- !ref <output_folder>/dev-clean.csv
- !ref <output_folder>/test-clean.csv
- !ref <output_folder>/test-other.csv

####################### Training Parameters ####################################

number_of_epochs: 500
batch_size: 16 # This works for 2x GPUs with 32GB
grad_accumulation_factor: 2
max_grad_norm: 5.0
sorting: descending #random
num_workers: 8
loss_reduction: batchmean
valid_search_interval: 1
avg_checkpoints: 10 # Number of checkpoints to average for evaluation

lr_model: 0.001
weight_decay: 0.0005

# Feature parameters
sample_rate: 16000
n_fft: 512
n_mels: 80
win_length: 25

# Training parameters
# To make Transformers converge, the global batch size should be large enough.
# The global batch size is max_batch_len * n_gpus * gradient_accumulation.
# Empirically, we used 850 * 8 A40 45G GPUs * 2 or 1700 * 4 A100 80G * 2.
# Please, set your parameters accordingly.
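# As a concrete example of the arithmetic above: with max_batch_length_train
# of 850 and grad_accumulation_factor of 2 on 8 GPUs, each optimizer step
# covers roughly 850 * 8 * 2 = 13600 length units.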
dynamic_batching: True
max_batch_length_train: 850
max_batch_len_val: 100
num_bucket: 200
shuffle: False # if True, re-creates batches at each epoch by shuffling examples.
max_batch_ex: 128
batch_ordering: random

dynamic_batch_sampler_train:
    max_batch_length: !ref <max_batch_length_train>
    num_buckets: !ref <num_bucket>
    shuffle: !ref <shuffle>
    batch_ordering: !ref <batch_ordering>
    max_batch_ex: !ref <max_batch_ex>

dynamic_batch_sampler_val:
    max_batch_length: !ref <max_batch_len_val>
    num_buckets: !ref <num_bucket>
    shuffle: !ref <shuffle>
    batch_ordering: !ref <batch_ordering>
    max_batch_ex: !ref <max_batch_ex>

# Dataloader options
train_dataloader_opts:
    batch_size: !ref <batch_size>
    shuffle: True
    num_workers: !ref <num_workers>

valid_dataloader_opts:
    batch_size: 1

test_dataloader_opts:
    batch_size: 1

####################### Model parameters ###########################

# Transformer
attention_type: RelPosMHAXL
d_model: 256
nhead: 4
csgu_linear_units: 2400
csgu_kernel_size: 31
num_encoder_layers: 18
num_decoder_layers: 0
transformer_dropout: 0.1
activation: !name:torch.nn.GELU
output_neurons: 31
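# (one CTC output unit per character token, including the blank, bos and eos
# symbols indexed below)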

# BPE parameters
token_type: char # ["unigram", "bpe", "char"]
character_coverage: 1.0
blank_index: 0
bos_index: 1
eos_index: 2

# Decoding parameters
beam_size: 100
beam_prune_logp: -12.0
token_prune_min_logp: -1.2
prune_history: False

############################## models ################################

CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
    input_shape: (8, 10, 80)
    num_blocks: 2
    num_layers_per_block: 1
    out_channels: (64, 32)
    kernel_sizes: (3, 3)
    strides: (2, 2)
    residuals: (False, False)
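# The two stride-2 blocks above downsample the time axis by 4x and reduce the
# 80 mel bins to 20; flattened with the final 32 channels this yields
# 20 * 32 = 640 features, matching the Transformer input_size below.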

Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length
    input_size: 640
    tgt_vocab: !ref <output_neurons>
    d_model: !ref <d_model>
    nhead: !ref <nhead>
    num_encoder_layers: !ref <num_encoder_layers>
    num_decoder_layers: !ref <num_decoder_layers>
    dropout: !ref <transformer_dropout>
    activation: !ref <activation>
    encoder_module: branchformer
    attention_type: !ref <attention_type>
    normalize_before: True
    causal: False
    csgu_linear_units: !ref <csgu_linear_units>
    kernel_size: !ref <csgu_kernel_size>

ctc_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <d_model>
    n_neurons: !ref <output_neurons>

normalize: !new:speechbrain.processing.features.InputNormalization
    norm_type: global
    update_until_epoch: 4

modules:
    CNN: !ref <CNN>
    Transformer: !ref <Transformer>
    ctc_lin: !ref <ctc_lin>
    normalize: !ref <normalize>

model: !new:torch.nn.ModuleList
    - [!ref <CNN>, !ref <Transformer>, !ref <ctc_lin>]

####################### Decoding & optimiser ###########################

# Decoding parameters
test_beam_search:
    blank_index: !ref <blank_index>
    beam_size: !ref <beam_size>
    beam_prune_logp: !ref <beam_prune_logp>
    token_prune_min_logp: !ref <token_prune_min_logp>
    prune_history: !ref <prune_history>

ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: !ref <blank_index>
    reduction: !ref <loss_reduction>

noam_annealing: !new:speechbrain.nnet.schedulers.LinearNoamScheduler
    lr_initial: !ref <lr_model>
    n_warmup_steps: 7500
    n_keep_steps: 36000

model_opt_class: !name:torch.optim.AdamW
    lr: !ref <lr_model>
    betas: (0.9, 0.98)
    eps: 0.000000001
    weight_decay: !ref <weight_decay>

log_softmax: !new:torch.nn.LogSoftmax
    dim: -1

############################## Augmentations ###################################

speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
    orig_freq: !ref <sample_rate>
    speeds: [95, 100, 105]

drop_freq: !new:speechbrain.augment.time_domain.DropFreq
    drop_freq_low: 0
    drop_freq_high: 1
    drop_freq_count_low: 1
    drop_freq_count_high: 3
    drop_freq_width: 0.05

drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
    drop_length_low: 1000
    drop_length_high: 2000
    drop_count_low: 1
    drop_count_high: 5

# Augmenter: Combines previously defined augmentations to perform data augmentation
wav_augment: !new:speechbrain.augment.augmenter.Augmenter
    parallel_augment: False
    concat_original: True
    repeat_augment: 1
    shuffle_augmentations: False
    min_augmentations: 4
    max_augmentations: 4
    augment_prob: 1.0
    augmentations: [
        !ref <speed_perturb>,
        !ref <drop_freq>,
        !ref <drop_chunk>
    ]
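# Note: with concat_original: True the clean batch is kept and the augmented
# copies are appended to it, so the model sees both original and augmented
# examples at every step (effectively enlarging the batch).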

compute_features: !new:speechbrain.lobes.features.Fbank
    sample_rate: !ref <sample_rate>
    n_fft: !ref <n_fft>
    win_length: !ref <win_length>
    n_mels: !ref <n_mels>

############################## Logging and Pretrainer ##########################

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        noam_scheduler: !ref <noam_annealing>
        normalizer: !ref <normalize>
        counter: !ref <epoch_counter>

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True

wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats