-
Notifications
You must be signed in to change notification settings - Fork 1
/
Train.yaml
250 lines (200 loc) · 7.07 KB
/
Train.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
# #################################
# Basic training parameters for integrated system
#
# Author:
# * Wen 2023
# #################################
# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 929
# Calls torch.manual_seed(<seed>) as a side effect of loading this file.
__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
device: cuda
# Root folder of the input manifests and label files.
# NOTE(review): the value already ends in "/", so every "<scp_folder>/..."
# path below expands with a double slash ("data//...").
scp_folder: !ref data/
# Per-seed experiment directory; checkpoints and logs are written below it.
output_folder: !ref exp_<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt
test_log: !ref <output_folder>/test_log.txt
wer_file: !ref <output_folder>/wer.txt
# NOTE(review): presumably these switch the test pass to the diarization / VAD
# forward annotations (fwd_drz_annotation / fwd_vad_annotation) -- confirm in
# the training script, which is not visible here.
FWD_DRZ: False
FWD_VAD: False
# Selects which fold's train/valid/test manifests are used below.
test_id: 5 # leave-one-session-out 5 fold cross validation
train_annotation: !ref <scp_folder>/iemo-train-<test_id>-MTL.json
valid_annotation: !ref <scp_folder>/iemo-valid-<test_id>-MTL.json
test_annotation: !ref <scp_folder>/iemo-test-<test_id>-MTL.json
fwd_vad_annotation: !ref <scp_folder>/metadata/iemo-test-<test_id>-DiagAll_3-2.json # During testing, dialogue input to the system with window width 3s and overlap 2s
fwd_drz_annotation: !ref <scp_folder>/iemo-test-<test_id>-drz.json
# Label arrays and label-encoder files, all read from <scp_folder>.
emo_lab: !ref <scp_folder>/iemo-lab-MajCat6.npy
vad_lab: !ref <scp_folder>/iemo_VAD_lab-utt.npy
lab_enc_file_emo: !ref <scp_folder>/label_encoder_emo.txt
lab_enc_file_asr: !ref <scp_folder>/label_encoder_asr.txt
# File logger for training statistics, written to <train_log>.
# The keyword arguments must be nested under the `!new:` key; at top level
# the logger would be constructed without its required `save_file`.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>
    precision: 4
# File logger for test statistics, written to <test_log>.
# The keyword arguments must be nested under the `!new:` key; at top level
# the logger would be constructed without its required `save_file`.
test_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <test_log>
    precision: 4
ckpt_interval_minutes: 25  # save checkpoint every N min

# Keyword arguments for building the data loaders.
dataloader_options:
    batch_size: !ref <batch_size>
    shuffle: False
    # `!name:` yields a partial of PaddedBatch with `padded_keys` pre-bound.
    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
        # Must be a real list: PaddedBatch tests `key in padded_keys`, which
        # on a plain string ("sig, tokens,vad_lab") is a substring match and
        # would also accept unintended keys such as "lab" or "tok".
        padded_keys: [sig, tokens, vad_lab]

# Top-level option (not a DataLoader keyword argument).
sorting: descending
# Training Parameters
number_of_epochs: 40
# Also referenced by dataloader_options.batch_size.
batch_size: 4
# Initial values of the per-component NewBob schedulers
# (lr_annealing_ER / _SR / _SV / _VAD / _w2v2).
lr_ER: 0.6
lr_SR: 1.2
lr_SV: 0.2
lr_VAD: 0.4
lr_w2v2: 0.001
# Learning rate passed to both opt_class_SR and opt_class_ER.
lr: 0.8
# Weight decay of opt_class_SR and opt_class_ER respectively.
weight_decay_SR: 0.0001
weight_decay_ER: 0.0025
# NOTE(review): not referenced anywhere in the visible part of this file.
weight_decay: 0.0001
grad_accumulation_factor: 4
# Referenced by the `augmentation` object.
sample_rate: 16000
# wavlm encoder
# Hub identifier passed as `source` to the shared `wav2vec2` encoder.
wav2vec2_hub: microsoft/wavlm-base-plus
# Passed as `freeze` / `freeze_feature_extractor` to the encoder.
freeze_wav2vec: False
freeze_wav2vec_conv: True
output_all_hiddens: True
# Used as `input_dim` of every task head (enc_ER, enc_SR, XVector, enc_VAD).
wav2vec2_output_dim: 768
# Passed to each model.interface.
# NOTE(review): presumably the number of hidden states returned when
# output_all_hiddens is True -- confirm against model.interface.
num_pretrain_layers: 13
# ER head parameters
# `d_model` and `output_dim` of enc_ER.
hidden_dim_ER: 256
output_dim_ER: 6
# SR head parameters
# `output_dim`, `rnn_blocks` and `rnn_neurons` of enc_SR.
output_dim_SR: 29
num_hidden_SR: 4
hidden_dim_SR: 256
# Blank symbol index handed to ctc_cost.
blank_index: 0
# SV head parameters
# Passed as `inited` to embedding_model_SV.
load_pretrained_SV: True
# NOTE(review): pretrained_SV_hub, freeze_embedding_model and freeze_interface
# are not referenced in the visible part of this file; presumably they are
# read by the training script -- confirm.
pretrained_SV_hub: microsoft/wavlm-base-plus-sv
freeze_embedding_model: False
freeze_interface: False
# Shared by embedding_model_SV and classifier_SV.
xvector_output_dim: 512
# `num_labels` of classifier_SV.
output_dim_spk: 8
# NOTE(review): diff_lr and the coeff_* values are not referenced in the
# visible part of this file; the coeff_* values are presumably per-task loss
# weights -- confirm in the training script.
diff_lr: True
coeff_ER: 1
coeff_SR: 1
coeff_VAD: 1.2
coeff_SV: 1.2
# Softmax layer configured with apply_log=True.
# Each keyword argument is nested under its object so it is passed to that
# constructor/function instead of becoming a stray top-level key.
log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True

# CTC loss function with the blank index pre-bound (`!name:` builds a
# partial rather than calling the function).
ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: !ref <blank_index>

# Cross-entropy loss instance, constructed with default arguments.
compute_cost_CE: !new:torch.nn.CrossEntropyLoss
# Metric factories (`!name:` defers construction to the training script).
wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats

# Same statistics class with split_tokens=True. The argument has to be nested
# here; at top level this factory would be identical to wer_computer.
cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True

# Project-local metric from utils.py.
error_stats_Acc: !name:utils.MetricStats_acc
# Augmentation switch; the augmentation object below is built regardless of
# its value.
do_augment: False

# Time-domain augmentation; arguments nested so they reach the constructor.
augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
    sample_rate: !ref <sample_rate>
    speeds: [95, 100, 105]
# shared encoder
# WavLM loaded through the HuggingFace wav2vec2 wrapper. All settings are
# nested under the object so they are passed as constructor arguments
# (`source` and `save_path` are required by the wrapper).
wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
    source: !ref <wav2vec2_hub>
    output_norm: True
    freeze: !ref <freeze_wav2vec>
    freeze_feature_extractor: !ref <freeze_wav2vec_conv>
    output_all_hiddens: !ref <output_all_hiddens>
    # Local directory handed to the wrapper as `save_path`.
    save_path: data/wav2vec2_checkpoint
# One model.interface per task head, each built with the same
# num_pretrain_layers. The argument is nested under every object; flattened,
# it would be a four-times duplicated top-level key and the interfaces would
# be constructed without it.
interface_ER: !new:model.interface
    num_pretrain_layers: !ref <num_pretrain_layers>

interface_SR: !new:model.interface
    num_pretrain_layers: !ref <num_pretrain_layers>

interface_SV: !new:model.interface
    num_pretrain_layers: !ref <num_pretrain_layers>

interface_VAD: !new:model.interface
    num_pretrain_layers: !ref <num_pretrain_layers>
# Emotion-recognition head (project-local model.TransformerModel).
# Arguments nested so they are passed to the constructor.
enc_ER: !new:model.TransformerModel
    input_dim: !ref <wav2vec2_output_dim>
    output_dim: !ref <output_dim_ER>
    d_model: !ref <hidden_dim_ER>
    num_encoder_layers: 4
    dp: 0.3
# Speech-recognition head (project-local model.RNN_enc).
# Arguments nested so they are passed to the constructor.
enc_SR: !new:model.RNN_enc
    input_dim: !ref <wav2vec2_output_dim>
    output_dim: !ref <output_dim_SR>
    rnn_blocks: !ref <num_hidden_SR>
    rnn_neurons: !ref <hidden_dim_SR>
# Speaker embedding model (project-local model.XVector).
# Arguments nested so they are passed to the constructor. The three tdnn_*
# lists have five entries each.
embedding_model_SV: !new:model.XVector
    input_dim: !ref <wav2vec2_output_dim>
    xvector_output_dim: !ref <xvector_output_dim>
    tdnn_dim: [512, 512, 512, 512, 1500]
    tdnn_kernel: [5, 3, 3, 1, 1]
    tdnn_dilation: [1, 2, 3, 1, 1]
    # Receives the load_pretrained_SV flag.
    inited: !ref <load_pretrained_SV>
# Speaker classifier on top of the x-vector embedding (project-local
# model.XVector_classifier). Arguments nested so they reach the constructor.
classifier_SV: !new:model.XVector_classifier
    xvector_output_dim: !ref <xvector_output_dim>
    num_labels: !ref <output_dim_spk>
# Voice-activity-detection head (project-local model.VAD_linear_enc).
# Argument nested so it is passed to the constructor.
enc_VAD: !new:model.VAD_linear_enc
    input_dim: !ref <wav2vec2_output_dim>
# Epoch counter bounded by number_of_epochs; `limit` is a required
# constructor argument and therefore has to be nested here.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
# Mapping of every network component defined above. The entries must be
# nested: flattened, `modules` is null and each entry re-declares (and
# overwrites) the top-level object of the same name.
modules:
    wav2vec2: !ref <wav2vec2>
    interface_ER: !ref <interface_ER>
    interface_SR: !ref <interface_SR>
    interface_SV: !ref <interface_SV>
    interface_VAD: !ref <interface_VAD>
    enc_ER: !ref <enc_ER>
    enc_SR: !ref <enc_SR>
    enc_VAD: !ref <enc_VAD>
    embedding_model_SV: !ref <embedding_model_SV>
    classifier_SV: !ref <classifier_SV>
# Optimizer factories (`!name:` builds a partial; the parameters to optimize
# are supplied by the training script). Arguments are nested per optimizer;
# flattened, `lr`/`rho`/`eps`/`weight_decay` collide with each other and with
# the top-level hyperparameters of the same name.
# NOTE(review): both optimizers start from the shared <lr>; lr_SR / lr_ER are
# only used as scheduler initial values in this file.
opt_class_SR: !name:torch.optim.Adadelta
    lr: !ref <lr>
    rho: 0.95
    eps: 1.e-8
    weight_decay: !ref <weight_decay_SR>

opt_class_ER: !name:torch.optim.Adadelta
    lr: !ref <lr>
    rho: 0.95
    eps: 1.e-8
    weight_decay: !ref <weight_decay_ER>
# One NewBob learning-rate scheduler per component, each starting from its
# own lr_* value. Arguments are nested per scheduler; flattened, the five
# argument sets collapse into duplicated top-level keys and `initial_value`
# (required) never reaches the constructors.
lr_annealing_ER: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_ER>
    improvement_threshold: 0.001
    annealing_factor: 0.8
    patient: 2

lr_annealing_SR: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_SR>
    improvement_threshold: 0.001
    annealing_factor: 0.6
    patient: 2

lr_annealing_SV: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_SV>
    improvement_threshold: 0.001
    annealing_factor: 0.7
    patient: 2

lr_annealing_VAD: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_VAD>
    improvement_threshold: 0.001
    annealing_factor: 0.7
    patient: 2

lr_annealing_w2v2: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_w2v2>
    improvement_threshold: 0.001
    annealing_factor: 0.6
    patient: 2
# Saves and restores training state under <save_folder>.
# `checkpoints_dir` and `recoverables` are constructor arguments, and every
# recoverable is an entry of the nested `recoverables` mapping.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        # The shared encoder is fine-tuned (freeze_wav2vec: False) and its
        # scheduler is checkpointed below, so its weights are saved as well;
        # otherwise a resumed or reloaded run falls back to the hub weights.
        # NOTE: checkpoints written before this entry existed contain no
        # wav2vec2 file; resuming from them needs a partial load
        # (Checkpointer's allow_partial_load) or a fresh run.
        wav2vec2: !ref <wav2vec2>
        interface_ER: !ref <interface_ER>
        interface_SR: !ref <interface_SR>
        interface_SV: !ref <interface_SV>
        interface_VAD: !ref <interface_VAD>
        enc_ER: !ref <enc_ER>
        enc_SR: !ref <enc_SR>
        enc_VAD: !ref <enc_VAD>
        embedding_model_SV: !ref <embedding_model_SV>
        classifier_SV: !ref <classifier_SV>
        scheduler_ER: !ref <lr_annealing_ER>
        scheduler_SR: !ref <lr_annealing_SR>
        scheduler_SV: !ref <lr_annealing_SV>
        scheduler_VAD: !ref <lr_annealing_VAD>
        scheduler_w2v2: !ref <lr_annealing_w2v2>
        counter: !ref <epoch_counter>